From 6a672821f77eb4dfc418d8dab3fc29a4d7b73100 Mon Sep 17 00:00:00 2001 From: Vivek Kalyan Date: Wed, 18 Mar 2026 13:49:30 -0700 Subject: [PATCH 01/13] feat: Support PipelineTrainer with dedicated LocalBackend --- src/art/local/backend.py | 72 ++++- src/art/pipeline_trainer/trainer.py | 75 ++++- src/art/test/test_step_skipping.py | 2 +- src/art/unsloth/service.py | 63 +++- src/art/unsloth/train.py | 9 +- .../benchmarking/pull_model_trajectories.py | 5 +- .../test_pipeline_trainer_local_backend.py | 295 ++++++++++++++++++ 7 files changed, 495 insertions(+), 26 deletions(-) create mode 100644 tests/unit/test_pipeline_trainer_local_backend.py diff --git a/src/art/local/backend.py b/src/art/local/backend.py index 5baf200f3..dfea5c124 100644 --- a/src/art/local/backend.py +++ b/src/art/local/backend.py @@ -159,19 +159,59 @@ def _allocated_gpu_count(self, model: Model) -> int: def __enter__(self) -> Self: return self + async def __aenter__(self) -> Self: + return self + def __exit__( self, exc_type: type[BaseException] | None, exc: BaseException | None, tb: TracebackType | None, ) -> None: + try: + asyncio.get_running_loop() + except RuntimeError: + running_loop = False + else: + running_loop = True + + if running_loop or any( + getattr(service, "aclose", None) is not None + for service in self._services.values() + ): + warnings.warn( + "LocalBackend used as a sync context manager. Cleanup uses the " + "best-effort sync shutdown path and cannot await service " + "teardown safely here; use `async with LocalBackend(...)` or " + "`await backend.close()` instead.", + RuntimeWarning, + stacklevel=2, + ) self._close() + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc: BaseException | None, + tb: TracebackType | None, + ) -> None: + await self.close() + async def close(self) -> None: """ If running vLLM in a separate process, this will kill that process and close the communication threads. """ - self._close() + for _, service in self._services.items(): + # Keep this logic aligned with _close(), but avoid double-closing + # services that expose an awaited aclose() path. + aclose = getattr(service, "aclose", None) + if aclose is not None: + await aclose() + else: + close = getattr(service, "close", None) + if close is not None: + close() + close_proxy(service) def _close(self) -> None: for _, service in self._services.items(): @@ -219,21 +259,31 @@ def _model_inference_name(self, model: Model, step: int | None = None) -> str: If None, returns name for latest checkpoint (step 0 initially). """ + def _served_step() -> int | None: + if not isinstance(model, TrainableModel): + return None + if model.name not in self._services: + return None + from ..dev.validate import is_dedicated_mode + + if not is_dedicated_mode( + model._internal_config or dev.InternalModelConfig() + ): + return None + loaded_step = getattr(self._services[model.name], "_latest_step", None) + return loaded_step if isinstance(loaded_step, int) else None + # For LocalBackend, vLLM always serves LoRA adapters with @step suffix # Default to step 0 when not specified (the initial checkpoint created at registration) if step is not None: actual_step = step - elif model.name in self._services and self._in_process: - # In dedicated mode the service tracks which adapter vLLM has - # actually loaded. Reading the filesystem would race: the - # checkpoint directory appears before the HTTP reload completes. - svc = self._services[model.name] - loaded_step = getattr(svc, "_latest_step", None) - actual_step = ( - loaded_step if loaded_step is not None else self.__get_step(model) - ) else: - actual_step = self.__get_step(model) + # In dedicated mode the service tracks which adapter vLLM has + # actually loaded. Reading the filesystem would race: the checkpoint + # directory appears before the HTTP reload completes. + actual_step = _served_step() + if actual_step is None: + actual_step = self.__get_step(model) name = f"{model.name}@{actual_step}" logger.debug( f"[BACKEND] _model_inference_name: step_arg={step} " diff --git a/src/art/pipeline_trainer/trainer.py b/src/art/pipeline_trainer/trainer.py index 1458d153c..c2ed80f92 100644 --- a/src/art/pipeline_trainer/trainer.py +++ b/src/art/pipeline_trainer/trainer.py @@ -154,6 +154,7 @@ def __init__( total_scenarios=total_scenarios, num_workers=num_rollout_workers, ) + self._validate_backend_support() async def train(self, *, handle_signals: bool = True) -> None: """Run the training pipeline over the configured scenario iterator.""" @@ -277,6 +278,72 @@ async def _notify_policy() -> None: except asyncio.QueueFull: loop.create_task(self._output_queue.put(None)) + def _is_local_backend(self) -> bool: + from art.local.backend import LocalBackend + + return isinstance(self.backend, LocalBackend) + + def _local_backend_is_dedicated(self) -> bool: + if not isinstance(self.model, art.TrainableModel): + return False + from art.dev.validate import is_dedicated_mode + + return is_dedicated_mode( + self.model._internal_config or art.dev.InternalModelConfig() + ) + + def _validate_backend_support(self) -> None: + if not self._is_local_backend(): + return + if self._local_backend_is_dedicated(): + self._validate_local_backend_train_config() + return + raise ValueError( + "PipelineTrainer only supports LocalBackend in dedicated mode. " + "Shared LocalBackend pauses inference during training and is not " + "a supported async PipelineTrainer path. Set both " + "trainer_gpu_ids and inference_gpu_ids on the TrainableModel " + "_internal_config to use LocalBackend with PipelineTrainer." + ) + + def _validate_local_backend_train_config(self) -> None: + if self.loss_fn not in {"cispo", "ppo"}: + raise ValueError( + "PipelineTrainer + LocalBackend(dedicated) only supports " + "loss_fn='cispo' or loss_fn='ppo'." + ) + if self.loss_fn_config is not None: + raise ValueError( + "PipelineTrainer + LocalBackend(dedicated) requires " + "loss_fn_config=None." + ) + if not self.normalize_advantages: + raise ValueError( + "PipelineTrainer + LocalBackend(dedicated) requires " + "normalize_advantages=True." + ) + if self.adam_params is not None: + raise ValueError( + "PipelineTrainer + LocalBackend(dedicated) requires adam_params=None." + ) + + def _backend_train_kwargs(self, *, save_checkpoint: bool) -> dict[str, Any]: + if not self._is_local_backend(): + return { + "learning_rate": self.learning_rate, + "loss_fn": self.loss_fn, + "loss_fn_config": self.loss_fn_config, + "normalize_advantages": self.normalize_advantages, + "save_checkpoint": save_checkpoint, + "adam_params": self.adam_params, + } + + return { + "learning_rate": self.learning_rate, + "ppo": self.loss_fn == "ppo", + "save_checkpoint": save_checkpoint, + } + async def _skip_scenarios( self, scenarios: AsyncIterator[ScenarioT], count: int ) -> int: @@ -412,18 +479,14 @@ async def _training_stage(self) -> None: self._status.note_training_start(len(batch)) train_call_start = time.monotonic() + train_kwargs = self._backend_train_kwargs(save_checkpoint=should_checkpoint) if os.getenv("ART_TRAIN_STEP_LOG"): print(f"[train] step {expected_step} starting (batch={len(batch)})") try: result = await self.backend.train( self.model, batch, - learning_rate=self.learning_rate, - loss_fn=self.loss_fn, - loss_fn_config=self.loss_fn_config, - normalize_advantages=self.normalize_advantages, - save_checkpoint=should_checkpoint, - adam_params=self.adam_params, + **train_kwargs, ) except Exception: self._status.note_training_end() diff --git a/src/art/test/test_step_skipping.py b/src/art/test/test_step_skipping.py index f4c85a1b8..0a048b5ad 100755 --- a/src/art/test/test_step_skipping.py +++ b/src/art/test/test_step_skipping.py @@ -44,7 +44,7 @@ async def test_step_skipping(): # Set up backend with custom art path art_path = os.path.join(tmpdir, ".art") - with LocalBackend(path=art_path) as backend: + async with LocalBackend(path=art_path) as backend: # Create a test model model = TrainableModel( name=f"test-step-skip-{uuid.uuid4()}", diff --git a/src/art/unsloth/service.py b/src/art/unsloth/service.py index f24be80f4..d590b082e 100644 --- a/src/art/unsloth/service.py +++ b/src/art/unsloth/service.py @@ -35,7 +35,7 @@ from ..utils.get_model_step import get_step_from_dir from ..utils.output_dirs import get_step_checkpoint_dir from ..vllm import get_llm, get_worker, openai_server_task, run_on_workers -from .train import gc_and_empty_cuda_cache, train +from .train import StopTrainingLoop, gc_and_empty_cuda_cache, train logger = logging.getLogger(__name__) @@ -55,6 +55,15 @@ class SupportsLoadLora(Protocol): def load_lora(self, lora_path: str, load_tensors: bool = True) -> LoRARequest: ... +class _StopTrainInputs: + """Dedicated sentinel for stopping the background trainer loop.""" + + +_STOP_TRAIN_INPUT = _StopTrainInputs() +_TRAIN_TASK_GRACEFUL_SHUTDOWN_TIMEOUT_S = 5.0 +_TRAIN_TASK_CANCEL_TIMEOUT_S = 1.0 + + def precalculate_new_logprobs( trainer: "GRPOTrainer", peft_model: "PeftModelForCausalLM", @@ -91,7 +100,7 @@ async def process_train_batch( packed_tensors: PackedTensors, config: types.TrainConfig, _config: dev.TrainConfig, - inputs_queue: asyncio.Queue[TrainInputs], + inputs_queue: asyncio.Queue[TrainInputs | _StopTrainInputs], results_queue: asyncio.Queue[dict[str, float]], train_task: asyncio.Task[None], trainer: "GRPOTrainer", @@ -215,7 +224,7 @@ class UnslothState: tokenizer: PreTrainedTokenizerBase peft_model: peft.peft_model.PeftModelForCausalLM trainer: GRPOTrainer - inputs_queue: asyncio.Queue[TrainInputs] + inputs_queue: asyncio.Queue[TrainInputs | _StopTrainInputs] results_queue: asyncio.Queue[dict[str, float]] _is_offloaded: bool = False _pinned_buffers: dict[str, torch.Tensor] | None = None @@ -316,6 +325,7 @@ class UnslothService: _vllm_log_file: Any = field(default=None, repr=False) _vllm_host: str = "127.0.0.1" _vllm_port: int = 0 + _train_task: asyncio.Task[None] | None = field(default=None, init=False, repr=False) @property def is_dedicated(self) -> bool: @@ -326,6 +336,46 @@ def _next_lora_id(self) -> int: self._lora_id_counter += 1 return self._lora_id_counter + def _request_train_task_stop(self) -> asyncio.Task[None] | None: + train_task = self._train_task + if train_task is None: + return None + if train_task.done(): + return train_task + + # `_state` is a cached_property. Read from __dict__ directly so shutdown + # does not instantiate the full trainer state solely to stop a task. + state = self.__dict__.get("_state") + if isinstance(state, UnslothState): + state.inputs_queue.put_nowait(_STOP_TRAIN_INPUT) + return train_task + + async def _shutdown_train_task(self) -> None: + train_task = self._request_train_task_stop() + if train_task is None: + return + + try: + # Give the trainer loop time to consume the stop sentinel and exit + # normally before falling back to cancellation. + await asyncio.wait_for( + train_task, timeout=_TRAIN_TASK_GRACEFUL_SHUTDOWN_TIMEOUT_S + ) + except asyncio.TimeoutError: + train_task.cancel() + try: + await asyncio.wait_for(train_task, timeout=_TRAIN_TASK_CANCEL_TIMEOUT_S) + except (asyncio.CancelledError, asyncio.TimeoutError): + pass + except asyncio.CancelledError: + pass + finally: + self._train_task = None + + async def aclose(self) -> None: + await self._shutdown_train_task() + self.close() + # ========================================================================= # Dedicated mode: vLLM subprocess lifecycle # ========================================================================= @@ -450,6 +500,7 @@ async def _reload_adapter(self, checkpoint_path: str, step: int) -> None: def close(self) -> None: """Terminate vLLM subprocess if running.""" + self._request_train_task_stop() if self._vllm_process is None: return self._vllm_process.terminate() @@ -981,17 +1032,19 @@ def _state(self) -> UnslothState: trainer.create_optimizer() # Initialize queues - inputs_queue: asyncio.Queue[TrainInputs] = asyncio.Queue() + inputs_queue: asyncio.Queue[TrainInputs | _StopTrainInputs] = asyncio.Queue() results_queue: asyncio.Queue[dict[str, float]] = asyncio.Queue() # Patch trainer _prepare_inputs() to pull from queue def _async_prepare_inputs(*_: Any, **__: Any) -> dict[str, torch.Tensor]: - async def get_inputs() -> TrainInputs: + async def get_inputs() -> TrainInputs | _StopTrainInputs: return await inputs_queue.get() # Force otherwise synchronous _prepare_inputs() to yield # with nested asyncio.run() call inputs = asyncio.run(get_inputs()) + if isinstance(inputs, _StopTrainInputs): + raise StopTrainingLoop() return cast(dict[str, torch.Tensor], inputs) diff --git a/src/art/unsloth/train.py b/src/art/unsloth/train.py index f6c42a2c0..8f9436914 100644 --- a/src/art/unsloth/train.py +++ b/src/art/unsloth/train.py @@ -39,6 +39,10 @@ } +class StopTrainingLoop(Exception): + """Signal that the background trainer loop should exit cleanly.""" + + def _canonicalize_upstream_metric_key(metric: str) -> str: if "/" in metric: return metric @@ -74,7 +78,10 @@ async def train( if not is_train_dict: trainer._metrics = {"train": defaultdict(list)} try: - trainer.train() + try: + trainer.train() + except StopTrainingLoop: + return finally: trainer.compute_loss = _compute_loss trainer.log = _log # ty:ignore[invalid-assignment] diff --git a/src/art/utils/benchmarking/pull_model_trajectories.py b/src/art/utils/benchmarking/pull_model_trajectories.py index b8a06bfe4..2708d5cf0 100644 --- a/src/art/utils/benchmarking/pull_model_trajectories.py +++ b/src/art/utils/benchmarking/pull_model_trajectories.py @@ -31,8 +31,9 @@ async def pull_model_trajectories(model: ArtModel) -> None: "Environment variable BACKUP_BUCKET is required but was not found." ) - # Use the LocalBackend context manager to work with the on-disk artefacts. - with LocalBackend() as backend: + # Use the LocalBackend async context manager so backend cleanup can await + # any background service shutdown before returning. + async with LocalBackend() as backend: print( f"Pulling trajectories for model '{model.name}' from S3 bucket '{bucket}'…", flush=True, diff --git a/tests/unit/test_pipeline_trainer_local_backend.py b/tests/unit/test_pipeline_trainer_local_backend.py new file mode 100644 index 000000000..80e02e72f --- /dev/null +++ b/tests/unit/test_pipeline_trainer_local_backend.py @@ -0,0 +1,295 @@ +import asyncio +from pathlib import Path +from types import SimpleNamespace +from typing import Any, cast +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from art import TrainableModel, Trajectory, TrajectoryGroup +from art.dev.model import InternalModelConfig +from art.local import LocalBackend +from art.pipeline_trainer.trainer import PipelineTrainer +from art.utils.output_dirs import get_model_dir + + +def _make_group( + rewards: list[float], *, initial_policy_version: int | None = 0 +) -> TrajectoryGroup: + return TrajectoryGroup( + [ + Trajectory( + reward=reward, + initial_policy_version=initial_policy_version, + messages_and_choices=[ + {"role": "user", "content": f"prompt-{idx}"}, + {"role": "assistant", "content": f"answer-{idx}"}, + ], + ) + for idx, reward in enumerate(rewards) + ] + ) + + +def _make_trainer( + *, + model: TrainableModel, + backend: object, + tmp_path: Path, + **kwargs: Any, +) -> PipelineTrainer: + return PipelineTrainer( + model=model, + backend=backend, # type: ignore[arg-type] + rollout_fn=lambda *_args, **_kwargs: asyncio.sleep(0), + scenarios=[], + config={}, + num_rollout_workers=1, + min_batch_size=1, + max_batch_size=1, + max_steps=1, + eval_fn=None, + **kwargs, + ) + + +@pytest.mark.asyncio +async def test_pipeline_trainer_preserves_default_backend_train_kwargs( + tmp_path: Path, +) -> None: + model = TrainableModel( + name="pipeline-default-backend-kwargs", + project="pipeline-tests", + base_model="test-model", + base_path=str(tmp_path), + ) + backend = MagicMock() + backend.train = AsyncMock(return_value=SimpleNamespace(step=1, metrics={})) + loss_fn_config = {"alpha": 0.1} + adam_params = object() + + trainer = _make_trainer( + model=model, + backend=backend, + tmp_path=tmp_path, + learning_rate=2e-5, + loss_fn="cispo", + loss_fn_config=loss_fn_config, + normalize_advantages=True, + adam_params=adam_params, + ) + trainer._output_queue = asyncio.Queue() + await trainer._output_queue.put(_make_group([0.0, 1.0])) + await trainer._output_queue.put(None) + + await trainer._training_stage() + + assert backend.train.await_count == 1 + assert backend.train.await_args.kwargs == { + "learning_rate": 2e-5, + "loss_fn": "cispo", + "loss_fn_config": loss_fn_config, + "normalize_advantages": True, + "save_checkpoint": False, + "adam_params": adam_params, + } + + +@pytest.mark.asyncio +async def test_pipeline_trainer_translates_local_backend_kwargs( + tmp_path: Path, +) -> None: + model = TrainableModel( + name="pipeline-local-backend-kwargs", + project="pipeline-tests", + base_model="test-model", + base_path=str(tmp_path), + _internal_config=InternalModelConfig( + trainer_gpu_ids=[0], + inference_gpu_ids=[1], + ), + ) + backend = LocalBackend(path=str(tmp_path)) + backend.train = AsyncMock(return_value=SimpleNamespace(step=1, metrics={})) # type: ignore[method-assign] + + trainer = _make_trainer( + model=model, + backend=backend, + tmp_path=tmp_path, + learning_rate=3e-5, + loss_fn="ppo", + ) + trainer._output_queue = asyncio.Queue() + await trainer._output_queue.put(_make_group([0.0, 1.0])) + await trainer._output_queue.put(None) + + await trainer._training_stage() + + assert backend.train.await_count == 1 # type: ignore[attr-defined] + assert backend.train.await_args.kwargs == { # type: ignore[attr-defined] + "learning_rate": 3e-5, + "ppo": True, + "save_checkpoint": False, + } + + +@pytest.mark.asyncio +async def test_local_backend_close_awaits_async_service_cleanup( + tmp_path: Path, +) -> None: + backend = LocalBackend(path=str(tmp_path)) + calls: list[str] = [] + + class FakeService: + async def aclose(self) -> None: + calls.append("aclose") + + service = FakeService() + backend._services["test-service"] = service # type: ignore[assignment] + + with patch("art.local.backend.close_proxy") as close_proxy: + await backend.close() + + assert calls == ["aclose"] + close_proxy.assert_called_once_with(service) + + +@pytest.mark.asyncio +async def test_local_backend_async_context_manager_awaits_async_service_cleanup( + tmp_path: Path, +) -> None: + backend = LocalBackend(path=str(tmp_path)) + calls: list[str] = [] + + class FakeService: + async def aclose(self) -> None: + calls.append("aclose") + + service = FakeService() + backend._services["test-service"] = service # type: ignore[assignment] + + with patch("art.local.backend.close_proxy") as close_proxy: + async with backend: + pass + + assert calls == ["aclose"] + close_proxy.assert_called_once_with(service) + + +@pytest.mark.asyncio +async def test_local_backend_sync_context_manager_warns_in_running_loop( + tmp_path: Path, +) -> None: + backend = LocalBackend(path=str(tmp_path)) + calls: list[str] = [] + + class FakeService: + def close(self) -> None: + calls.append("close") + + service = FakeService() + backend._services["test-service"] = service # type: ignore[assignment] + + with patch("art.local.backend.close_proxy") as close_proxy: + with pytest.warns(RuntimeWarning, match="async with LocalBackend"): + with backend: + pass + + assert calls == ["close"] + close_proxy.assert_called_once_with(service) + + +def test_local_backend_sync_context_manager_warns_and_uses_sync_close( + tmp_path: Path, +) -> None: + backend = LocalBackend(path=str(tmp_path)) + calls: list[str] = [] + + class FakeService: + def close(self) -> None: + calls.append("close") + + async def aclose(self) -> None: + calls.append("aclose") + + service = FakeService() + backend._services["test-service"] = service # type: ignore[assignment] + + with patch("art.local.backend.close_proxy") as close_proxy: + with pytest.warns(RuntimeWarning, match="best-effort sync shutdown"): + with backend: + pass + + assert calls == ["close"] + close_proxy.assert_called_once_with(service) + + +@pytest.mark.parametrize( + ("trainer_kwargs", "match"), + [ + ({"loss_fn": "dro"}, "loss_fn='cispo' or loss_fn='ppo'"), + ({"loss_fn_config": {"clip": 0.2}}, "loss_fn_config=None"), + ({"normalize_advantages": False}, "normalize_advantages=True"), + ({"adam_params": object()}, "adam_params=None"), + ], +) +def test_pipeline_trainer_rejects_unsupported_local_backend_settings( + tmp_path: Path, + trainer_kwargs: dict[str, object], + match: str, +) -> None: + model = TrainableModel( + name="pipeline-local-backend-invalid", + project="pipeline-tests", + base_model="test-model", + base_path=str(tmp_path), + _internal_config=InternalModelConfig( + trainer_gpu_ids=[0], + inference_gpu_ids=[1], + ), + ) + backend = LocalBackend(path=str(tmp_path)) + with pytest.raises(ValueError, match=match): + _make_trainer( + model=model, + backend=backend, + tmp_path=tmp_path, + **trainer_kwargs, + ) + + +def test_pipeline_trainer_rejects_shared_local_backend(tmp_path: Path) -> None: + model = TrainableModel( + name="pipeline-local-backend-shared", + project="pipeline-tests", + base_model="test-model", + base_path=str(tmp_path), + ) + backend = LocalBackend(path=str(tmp_path)) + + with pytest.raises( + ValueError, match="only supports LocalBackend in dedicated mode" + ): + _make_trainer(model=model, backend=backend, tmp_path=tmp_path) + + +def test_local_backend_inference_name_prefers_served_step_in_dedicated_mode( + tmp_path: Path, +) -> None: + model = TrainableModel( + name="local-backend-served-step", + project="pipeline-tests", + base_model="test-model", + base_path=str(tmp_path), + _internal_config=InternalModelConfig( + trainer_gpu_ids=[0], + inference_gpu_ids=[1], + ), + ) + backend = LocalBackend(path=str(tmp_path)) + output_dir = Path(get_model_dir(model=model, art_path=str(tmp_path))) + (output_dir / "checkpoints" / "3").mkdir(parents=True) + backend._services[model.name] = cast(Any, SimpleNamespace(_latest_step=2)) + + assert backend._model_inference_name(model) == f"{model.name}@2" + assert backend._model_inference_name(model, step=3) == f"{model.name}@3" From 3067cc133f7e7f225c4f02302173193376c89cf8 Mon Sep 17 00:00:00 2001 From: Vivek Kalyan Date: Wed, 18 Mar 2026 13:49:53 -0700 Subject: [PATCH 02/13] test: Add dedicated LocalBackend smoke coverage --- .../test_pipeline_localbackend_dedicated.py | 184 ++++++++++++++++++ 1 file changed, 184 insertions(+) create mode 100644 tests/integration/test_pipeline_localbackend_dedicated.py diff --git a/tests/integration/test_pipeline_localbackend_dedicated.py b/tests/integration/test_pipeline_localbackend_dedicated.py new file mode 100644 index 000000000..2665ac102 --- /dev/null +++ b/tests/integration/test_pipeline_localbackend_dedicated.py @@ -0,0 +1,184 @@ +"""Dedicated LocalBackend smoke test for PipelineTrainer.""" + +import asyncio +import os +import tempfile +import uuid + +import openai +import pytest + +torch = pytest.importorskip("torch") +pytest.importorskip("vllm") + +import art +from art.local import LocalBackend +from art.pipeline_trainer import PipelineTrainer + +DEFAULT_BASE_MODEL = "Qwen/Qwen3-0.6B" +DEFAULT_GPU_MEMORY_UTILIZATION = 0.2 +DEFAULT_MAX_MODEL_LEN = 2048 +DEFAULT_MAX_SEQ_LENGTH = 2048 + + +def get_base_model() -> str: + return os.environ.get("BASE_MODEL", DEFAULT_BASE_MODEL) + + +def get_safe_gpu_memory_utilization() -> float: + requested = float( + os.environ.get( + "ART_TEST_GPU_MEMORY_UTILIZATION", + str(DEFAULT_GPU_MEMORY_UTILIZATION), + ) + ) + min_free_gib = float(os.environ.get("ART_TEST_MIN_FREE_GPU_GIB", "8")) + free_ratios: list[float] = [] + for device in (0, 1): + free_bytes, total_bytes = torch.cuda.mem_get_info(device) + free_gib = free_bytes / (1024**3) + if free_gib < min_free_gib: + pytest.skip( + "Insufficient free GPU memory for dedicated LocalBackend smoke test: " + f"GPU {device} has {free_gib:.1f} GiB free < {min_free_gib:.1f} GiB required." + ) + free_ratios.append(free_bytes / total_bytes) + return max(0.02, min(requested, min(free_ratios) * 0.8)) + + +def get_dedicated_vllm_test_config() -> art.dev.InternalModelConfig: + return { + "trainer_gpu_ids": [0], + "inference_gpu_ids": [1], + "engine_args": { + "gpu_memory_utilization": get_safe_gpu_memory_utilization(), + "max_model_len": int( + os.environ.get("ART_TEST_MAX_MODEL_LEN", str(DEFAULT_MAX_MODEL_LEN)) + ), + "max_num_seqs": 8, + "enforce_eager": True, + }, + "init_args": { + "max_seq_length": int( + os.environ.get("ART_TEST_MAX_SEQ_LENGTH", str(DEFAULT_MAX_SEQ_LENGTH)) + ), + }, + } + + +def reward_for_answer(text: str) -> float: + content = text.lower() + if "yes" in content: + return 1.0 + if "no" in content: + return 0.5 + if "maybe" in content: + return 0.25 + return 0.0 + + +async def assert_chat_logprobs( + client: openai.AsyncOpenAI, + model_name: str, +) -> None: + completion = await client.chat.completions.create( + messages=[{"role": "user", "content": "Say hello."}], + model=model_name, + max_tokens=8, + timeout=60, + logprobs=True, + top_logprobs=0, + ) + assert completion.choices[0].logprobs is not None + + +@pytest.mark.skipif( + not torch.cuda.is_available() or torch.cuda.device_count() < 2, + reason="Need at least 2 CUDA GPUs for dedicated LocalBackend PipelineTrainer test", +) +async def test_pipeline_trainer_local_backend_dedicated_smoke() -> None: + model_name = f"test-pipeline-local-dedicated-{uuid.uuid4().hex[:8]}" + prompts = [ + "Say yes", + "Say no", + "Say maybe", + "Say hello", + "Say yes again", + "Say no again", + ] + client: openai.AsyncOpenAI | None = None + + async def rollout_fn( + model: art.TrainableModel, + scenario: dict[str, str], + _config: None, + ) -> art.TrajectoryGroup: + await asyncio.sleep(0.2) + messages: art.Messages = [{"role": "user", "content": scenario["prompt"]}] + assert client is not None + completion = await client.chat.completions.create( + messages=messages, + model=model.get_inference_name(), + max_tokens=10, + timeout=60, + temperature=1, + n=2, + logprobs=True, + top_logprobs=0, + ) + return art.TrajectoryGroup( + [ + art.Trajectory( + messages_and_choices=[*messages, choice], + reward=reward_for_answer(choice.message.content or ""), + ) + for choice in completion.choices + ] + ) + + async def scenario_iter(): + for prompt in prompts: + yield {"prompt": prompt} + + with tempfile.TemporaryDirectory() as tmpdir: + async with LocalBackend(path=tmpdir) as backend: + model = art.TrainableModel( + name=model_name, + project="integration-tests", + base_model=get_base_model(), + _internal_config=get_dedicated_vllm_test_config(), + ) + client: openai.AsyncOpenAI | None = None + try: + await model.register(backend) + client = model.openai_client() + trainer = PipelineTrainer( + model=model, + backend=backend, + rollout_fn=rollout_fn, + scenarios=scenario_iter(), + config=None, + num_rollout_workers=2, + min_batch_size=1, + max_batch_size=1, + max_steps=2, + loss_fn="cispo", + eval_fn=None, + ) + + await trainer.train() + + latest_step = await model.get_step() + assert latest_step >= 2 + + await assert_chat_logprobs(client, model.get_inference_name(step=0)) + await assert_chat_logprobs( + client, model.get_inference_name(step=latest_step) + ) + + model_ids = [m.id async for m in client.models.list()] + assert f"{model.name}@0" in model_ids + assert f"{model.name}@{latest_step}" in model_ids + finally: + if client is not None: + await client.close() From e4c6b2be31708ded2c350d42847236431910d1e1 Mon Sep 17 00:00:00 2001 From: Vivek Kalyan Date: Wed, 18 Mar 2026 13:50:01 -0700 Subject: [PATCH 03/13] docs: Document dedicated LocalBackend pipeline support --- docs/features/checkpoint-forking.mdx | 4 ++-- docs/fundamentals/art-backend.mdx | 24 ++++++++++++++++++++++++ docs/fundamentals/training-loop.mdx | 2 ++ 3 files changed, 28 insertions(+), 2 deletions(-) diff --git a/docs/features/checkpoint-forking.mdx b/docs/features/checkpoint-forking.mdx index c3d3603df..4a214c855 100644 --- a/docs/features/checkpoint-forking.mdx +++ b/docs/features/checkpoint-forking.mdx @@ -31,7 +31,7 @@ import art from art.local import LocalBackend async def train(): - with LocalBackend() as backend: + async with LocalBackend() as backend: # Create a new model that will fork from an existing checkpoint model = art.TrainableModel( name="my-model-v2", @@ -115,7 +115,7 @@ low_lr_model = art.TrainableModel( ) async def experiment(): - with LocalBackend() as backend: + async with LocalBackend() as backend: # Fork the model from the base model await backend._experimental_fork_checkpoint( low_lr_model, diff --git a/docs/fundamentals/art-backend.mdx b/docs/fundamentals/art-backend.mdx index f65473a99..9e6019c0b 100644 --- a/docs/fundamentals/art-backend.mdx +++ b/docs/fundamentals/art-backend.mdx @@ -73,6 +73,30 @@ backend = LocalBackend( ) ``` +If you're using `PipelineTrainer`, `LocalBackend` is currently supported only in dedicated mode, where training and inference run on separate GPUs. + +```python +from art import TrainableModel +from art.dev import InternalModelConfig +from art.local import LocalBackend + +backend = LocalBackend(path="./.art") +model = TrainableModel( + name="pipeline-localbackend", + project="my-project", + base_model="Qwen/Qwen3-0.6B", + _internal_config=InternalModelConfig( + trainer_gpu_ids=[0], + inference_gpu_ids=[1], + ), +) +``` + +Shared `LocalBackend` still pauses inference during training, so ART rejects that configuration for `PipelineTrainer`. + +In dedicated mode, a new checkpoint becomes the default inference target only after its LoRA has been reloaded into vLLM. That checkpoint publication flow is backend-specific, so `save_checkpoint` does not have identical semantics across every ART backend. +Requests that are already in flight keep using the adapter they started with; the reload only affects subsequent routing to the latest served step. + ## Using a backend Once initialized, a backend can be used in the same way regardless of whether it runs locally or remotely. diff --git a/docs/fundamentals/training-loop.mdx b/docs/fundamentals/training-loop.mdx index b7bc8fe99..4c8f4f75f 100644 --- a/docs/fundamentals/training-loop.mdx +++ b/docs/fundamentals/training-loop.mdx @@ -22,6 +22,8 @@ ART's functionality is divided into a [**client**](/fundamentals/art-client) and This training loop runs until a specified number of inference and training iterations have completed. +This describes the default shared-resource loop. `PipelineTrainer` can also run with `LocalBackend` in dedicated mode, where training and inference stay on separate GPUs and the latest served step advances only after vLLM reloads the new LoRA. + Training and inference use both the ART **client** and **backend**. Learn more by following the links below!
From 47579de3e8dbe8ad97400f5fe1708db3dc1676a9 Mon Sep 17 00:00:00 2001 From: Vivek Kalyan Date: Wed, 18 Mar 2026 14:52:01 -0700 Subject: [PATCH 04/13] refactor: simplify LocalBackend pipeline trainer integration --- src/art/local/backend.py | 96 ++++++------- src/art/pipeline_trainer/trainer.py | 62 +++------ src/art/unsloth/service.py | 59 +++----- src/art/unsloth/train.py | 7 +- .../test_pipeline_localbackend_dedicated.py | 76 +++++----- .../test_pipeline_trainer_local_backend.py | 131 ++++++------------ 6 files changed, 165 insertions(+), 266 deletions(-) diff --git a/src/art/local/backend.py b/src/art/local/backend.py index dfea5c124..771b81632 100644 --- a/src/art/local/backend.py +++ b/src/art/local/backend.py @@ -168,25 +168,6 @@ def __exit__( exc: BaseException | None, tb: TracebackType | None, ) -> None: - try: - asyncio.get_running_loop() - except RuntimeError: - running_loop = False - else: - running_loop = True - - if running_loop or any( - getattr(service, "aclose", None) is not None - for service in self._services.values() - ): - warnings.warn( - "LocalBackend used as a sync context manager. Cleanup uses the " - "best-effort sync shutdown path and cannot await service " - "teardown safely here; use `async with LocalBackend(...)` or " - "`await backend.close()` instead.", - RuntimeWarning, - stacklevel=2, - ) self._close() async def __aexit__( @@ -201,20 +182,18 @@ async def close(self) -> None: """ If running vLLM in a separate process, this will kill that process and close the communication threads. """ - for _, service in self._services.items(): - # Keep this logic aligned with _close(), but avoid double-closing - # services that expose an awaited aclose() path. + for service in self._services.values(): aclose = getattr(service, "aclose", None) - if aclose is not None: - await aclose() - else: + if aclose is None: close = getattr(service, "close", None) if close is not None: close() + else: + await aclose() close_proxy(service) def _close(self) -> None: - for _, service in self._services.items(): + for service in self._services.values(): close = getattr(service, "close", None) if close is not None: close() @@ -259,35 +238,27 @@ def _model_inference_name(self, model: Model, step: int | None = None) -> str: If None, returns name for latest checkpoint (step 0 initially). """ - def _served_step() -> int | None: - if not isinstance(model, TrainableModel): - return None - if model.name not in self._services: - return None + requested_step = step + + if step is None and isinstance(model, TrainableModel): from ..dev.validate import is_dedicated_mode - if not is_dedicated_mode( + service = self._services.get(model.name) + if service is not None and is_dedicated_mode( model._internal_config or dev.InternalModelConfig() ): - return None - loaded_step = getattr(self._services[model.name], "_latest_step", None) - return loaded_step if isinstance(loaded_step, int) else None - - # For LocalBackend, vLLM always serves LoRA adapters with @step suffix - # Default to step 0 when not specified (the initial checkpoint created at registration) - if step is not None: - actual_step = step - else: - # In dedicated mode the service tracks which adapter vLLM has - # actually loaded. Reading the filesystem would race: the checkpoint - # directory appears before the HTTP reload completes. - actual_step = _served_step() - if actual_step is None: - actual_step = self.__get_step(model) - name = f"{model.name}@{actual_step}" + loaded_step = getattr(service, "_latest_step", None) + if isinstance(loaded_step, int): + step = loaded_step + + if step is None: + # The checkpoint directory is written before dedicated-mode + # vLLM finishes reloading the new adapter. + step = self.__get_step(model) + name = f"{model.name}@{step}" logger.debug( - f"[BACKEND] _model_inference_name: step_arg={step} " - f"actual_step={actual_step} -> {name}" + f"[BACKEND] _model_inference_name: step_arg={requested_step} " + f"actual_step={step} -> {name}" ) return name @@ -552,12 +523,14 @@ async def train( # type: ignore[override] *, # Core training parameters learning_rate: float = 5e-6, + loss_fn: Literal["cispo", "ppo", "importance_sampling", "dro"] = "cispo", + loss_fn_config: dict | None = None, + normalize_advantages: bool = True, + adam_params: object | None = None, # KL-penalized advantage adjustment kl_penalty_coef: float = 0.0, kl_penalty_reference_step: int | None = None, kl_ref_adapter_path: str | None = None, - # RL algorithm settings - ppo: bool = False, epsilon: float | None = None, epsilon_high: float | None = None, # Advantage computation @@ -594,6 +567,14 @@ async def train( # type: ignore[override] model: The trainable model to train. trajectory_groups: Batches of trajectories to train on. learning_rate: Learning rate for training. Defaults to 5e-6. + loss_fn: RL loss function. LocalBackend currently supports + "cispo" and "ppo". + loss_fn_config: Additional loss-function config. Not supported by + LocalBackend. + normalize_advantages: Whether to normalize advantages. LocalBackend + currently requires True. + adam_params: Custom optimizer params. Not supported by + LocalBackend. kl_penalty_coef: Coefficient for KL-penalized advantage adjustment. Tokens diverging more from the reference get reduced advantages. Defaults to 0.0 (disabled). @@ -603,7 +584,6 @@ async def train( # type: ignore[override] kl_ref_adapter_path: Direct filesystem path to a LoRA adapter checkpoint to use as the KL reference. Alternative to kl_penalty_reference_step. - ppo: Whether to use PPO clipping. Defaults to False. epsilon: Clip epsilon for importance sampling. Defaults based on ppo. epsilon_high: Asymmetric upper clip bound. Defaults to epsilon. advantage_balance: Balance between negative and positive advantages @@ -647,6 +627,14 @@ async def train( # type: ignore[override] # await model.log(metrics=result.metrics, step=result.step) """ groups_list = list(trajectory_groups) + if loss_fn not in {"cispo", "ppo"}: + raise ValueError("LocalBackend only supports loss_fn='cispo' or 'ppo'.") + if loss_fn_config is not None: + raise ValueError("LocalBackend requires loss_fn_config=None.") + if not normalize_advantages: + raise ValueError("LocalBackend requires normalize_advantages=True.") + if adam_params is not None: + raise ValueError("LocalBackend requires adam_params=None.") # Build config objects from explicit kwargs config = TrainConfig( @@ -659,7 +647,7 @@ async def train( # type: ignore[override] "kl_penalty_coef": kl_penalty_coef, "mask_prob_ratio": mask_prob_ratio, "plot_tensors": plot_tensors, - "ppo": ppo, + "ppo": loss_fn == "ppo", "precalculate_logprobs": precalculate_logprobs, "scale_learning_rate_by_reward_std_dev": scale_learning_rate_by_reward_std_dev, "scale_rewards": scale_rewards, diff --git a/src/art/pipeline_trainer/trainer.py b/src/art/pipeline_trainer/trainer.py index c2ed80f92..77966857a 100644 --- a/src/art/pipeline_trainer/trainer.py +++ b/src/art/pipeline_trainer/trainer.py @@ -278,35 +278,22 @@ async def _notify_policy() -> None: except asyncio.QueueFull: loop.create_task(self._output_queue.put(None)) - def _is_local_backend(self) -> bool: - from art.local.backend import LocalBackend - - return isinstance(self.backend, LocalBackend) - - def _local_backend_is_dedicated(self) -> bool: - if not isinstance(self.model, art.TrainableModel): - return False + def _validate_backend_support(self) -> None: from art.dev.validate import is_dedicated_mode + from art.local.backend import LocalBackend - return is_dedicated_mode( - self.model._internal_config or art.dev.InternalModelConfig() - ) - - def _validate_backend_support(self) -> None: - if not self._is_local_backend(): - return - if self._local_backend_is_dedicated(): - self._validate_local_backend_train_config() + if not isinstance(self.backend, LocalBackend): return - raise ValueError( - "PipelineTrainer only supports LocalBackend in dedicated mode. " - "Shared LocalBackend pauses inference during training and is not " - "a supported async PipelineTrainer path. Set both " - "trainer_gpu_ids and inference_gpu_ids on the TrainableModel " - "_internal_config to use LocalBackend with PipelineTrainer." - ) - def _validate_local_backend_train_config(self) -> None: + model_config = self.model._internal_config or art.dev.InternalModelConfig() + if not is_dedicated_mode(model_config): + raise ValueError( + "PipelineTrainer only supports LocalBackend in dedicated mode. " + "Shared LocalBackend pauses inference during training and is not " + "a supported async PipelineTrainer path. Set both " + "trainer_gpu_ids and inference_gpu_ids on the TrainableModel " + "_internal_config to use LocalBackend with PipelineTrainer." + ) if self.loss_fn not in {"cispo", "ppo"}: raise ValueError( "PipelineTrainer + LocalBackend(dedicated) only supports " @@ -327,23 +314,6 @@ def _validate_local_backend_train_config(self) -> None: "PipelineTrainer + LocalBackend(dedicated) requires adam_params=None." ) - def _backend_train_kwargs(self, *, save_checkpoint: bool) -> dict[str, Any]: - if not self._is_local_backend(): - return { - "learning_rate": self.learning_rate, - "loss_fn": self.loss_fn, - "loss_fn_config": self.loss_fn_config, - "normalize_advantages": self.normalize_advantages, - "save_checkpoint": save_checkpoint, - "adam_params": self.adam_params, - } - - return { - "learning_rate": self.learning_rate, - "ppo": self.loss_fn == "ppo", - "save_checkpoint": save_checkpoint, - } - async def _skip_scenarios( self, scenarios: AsyncIterator[ScenarioT], count: int ) -> int: @@ -479,14 +449,18 @@ async def _training_stage(self) -> None: self._status.note_training_start(len(batch)) train_call_start = time.monotonic() - train_kwargs = self._backend_train_kwargs(save_checkpoint=should_checkpoint) if os.getenv("ART_TRAIN_STEP_LOG"): print(f"[train] step {expected_step} starting (batch={len(batch)})") try: result = await self.backend.train( self.model, batch, - **train_kwargs, + learning_rate=self.learning_rate, + loss_fn=self.loss_fn, + loss_fn_config=self.loss_fn_config, + normalize_advantages=self.normalize_advantages, + save_checkpoint=should_checkpoint, + adam_params=self.adam_params, ) except Exception: self._status.note_training_end() diff --git a/src/art/unsloth/service.py b/src/art/unsloth/service.py index d590b082e..5b6a563c2 100644 --- a/src/art/unsloth/service.py +++ b/src/art/unsloth/service.py @@ -60,8 +60,8 @@ class _StopTrainInputs: _STOP_TRAIN_INPUT = _StopTrainInputs() -_TRAIN_TASK_GRACEFUL_SHUTDOWN_TIMEOUT_S = 5.0 -_TRAIN_TASK_CANCEL_TIMEOUT_S = 1.0 +_TRAIN_TASK_SHUTDOWN_TIMEOUT_S = 5.0 +_TrainLoopInput = TrainInputs | _StopTrainInputs def precalculate_new_logprobs( @@ -100,7 +100,7 @@ async def process_train_batch( packed_tensors: PackedTensors, config: types.TrainConfig, _config: dev.TrainConfig, - inputs_queue: asyncio.Queue[TrainInputs | _StopTrainInputs], + inputs_queue: asyncio.Queue[_TrainLoopInput], results_queue: asyncio.Queue[dict[str, float]], train_task: asyncio.Task[None], trainer: "GRPOTrainer", @@ -224,7 +224,7 @@ class UnslothState: tokenizer: PreTrainedTokenizerBase peft_model: peft.peft_model.PeftModelForCausalLM trainer: GRPOTrainer - inputs_queue: asyncio.Queue[TrainInputs | _StopTrainInputs] + inputs_queue: asyncio.Queue[_TrainLoopInput] results_queue: asyncio.Queue[dict[str, float]] _is_offloaded: bool = False _pinned_buffers: dict[str, torch.Tensor] | None = None @@ -336,44 +336,22 @@ def _next_lora_id(self) -> int: self._lora_id_counter += 1 return self._lora_id_counter - def _request_train_task_stop(self) -> asyncio.Task[None] | None: + async def aclose(self) -> None: train_task = self._train_task - if train_task is None: - return None - if train_task.done(): - return train_task - - # `_state` is a cached_property. Read from __dict__ directly so shutdown - # does not instantiate the full trainer state solely to stop a task. - state = self.__dict__.get("_state") - if isinstance(state, UnslothState): - state.inputs_queue.put_nowait(_STOP_TRAIN_INPUT) - return train_task - - async def _shutdown_train_task(self) -> None: - train_task = self._request_train_task_stop() - if train_task is None: + self._train_task = None + if train_task is None or train_task.done(): + self.close() return + # `_state` is a cached_property. Read from __dict__ directly so + # closing does not instantiate trainer state only to stop a task. + state = self.__dict__.get("_state") + assert isinstance(state, UnslothState) + state.inputs_queue.put_nowait(_STOP_TRAIN_INPUT) try: - # Give the trainer loop time to consume the stop sentinel and exit - # normally before falling back to cancellation. - await asyncio.wait_for( - train_task, timeout=_TRAIN_TASK_GRACEFUL_SHUTDOWN_TIMEOUT_S - ) + await asyncio.wait_for(train_task, timeout=_TRAIN_TASK_SHUTDOWN_TIMEOUT_S) except asyncio.TimeoutError: train_task.cancel() - try: - await asyncio.wait_for(train_task, timeout=_TRAIN_TASK_CANCEL_TIMEOUT_S) - except (asyncio.CancelledError, asyncio.TimeoutError): - pass - except asyncio.CancelledError: - pass - finally: - self._train_task = None - - async def aclose(self) -> None: - await self._shutdown_train_task() self.close() # ========================================================================= @@ -500,7 +478,6 @@ async def _reload_adapter(self, checkpoint_path: str, step: int) -> None: def close(self) -> None: """Terminate vLLM subprocess if running.""" - self._request_train_task_stop() if self._vllm_process is None: return self._vllm_process.terminate() @@ -646,7 +623,7 @@ async def _train_dedicated( await self._state.results_queue.join() - if not hasattr(self, "_train_task") or self._train_task is None: + if self._train_task is None: self._train_task = asyncio.create_task( train( trainer=self._state.trainer, @@ -736,7 +713,7 @@ async def _train_shared( await self._state.results_queue.join() # If we haven't already, start the training task - if not hasattr(self, "_train_task") or self._train_task is None: + if self._train_task is None: self._train_task = asyncio.create_task( train( trainer=self._state.trainer, @@ -1032,12 +1009,12 @@ def _state(self) -> UnslothState: trainer.create_optimizer() # Initialize queues - inputs_queue: asyncio.Queue[TrainInputs | _StopTrainInputs] = asyncio.Queue() + inputs_queue: asyncio.Queue[_TrainLoopInput] = asyncio.Queue() results_queue: asyncio.Queue[dict[str, float]] = asyncio.Queue() # Patch trainer _prepare_inputs() to pull from queue def _async_prepare_inputs(*_: Any, **__: Any) -> dict[str, torch.Tensor]: - async def get_inputs() -> TrainInputs | _StopTrainInputs: + async def get_inputs() -> _TrainLoopInput: return await inputs_queue.get() # Force otherwise synchronous _prepare_inputs() to yield diff --git a/src/art/unsloth/train.py b/src/art/unsloth/train.py index 8f9436914..433879040 100644 --- a/src/art/unsloth/train.py +++ b/src/art/unsloth/train.py @@ -78,10 +78,9 @@ async def train( if not is_train_dict: trainer._metrics = {"train": defaultdict(list)} try: - try: - trainer.train() - except StopTrainingLoop: - return + trainer.train() + except StopTrainingLoop: + return finally: trainer.compute_loss = _compute_loss trainer.log = _log # ty:ignore[invalid-assignment] diff --git a/tests/integration/test_pipeline_localbackend_dedicated.py b/tests/integration/test_pipeline_localbackend_dedicated.py index 2665ac102..d6d04bc7b 100644 --- a/tests/integration/test_pipeline_localbackend_dedicated.py +++ b/tests/integration/test_pipeline_localbackend_dedicated.py @@ -106,39 +106,6 @@ async def test_pipeline_trainer_local_backend_dedicated_smoke() -> None: "Say yes again", "Say no again", ] - client: openai.AsyncOpenAI | None = None - - async def rollout_fn( - model: art.TrainableModel, - scenario: dict[str, str], - _config: None, - ) -> art.TrajectoryGroup: - await asyncio.sleep(0.2) - messages: art.Messages = [{"role": "user", "content": scenario["prompt"]}] - assert client is not None - completion = await client.chat.completions.create( - messages=messages, - model=model.get_inference_name(), - max_tokens=10, - timeout=60, - temperature=1, - n=2, - logprobs=True, - top_logprobs=0, - ) - return art.TrajectoryGroup( - [ - art.Trajectory( - messages_and_choices=[*messages, choice], - reward=reward_for_answer(choice.message.content or ""), - ) - for choice in completion.choices - ] - ) - - async def scenario_iter(): - for prompt in prompts: - yield {"prompt": prompt} with tempfile.TemporaryDirectory() as tmpdir: async with LocalBackend(path=tmpdir) as backend: @@ -148,10 +115,44 @@ async def scenario_iter(): base_model=get_base_model(), _internal_config=get_dedicated_vllm_test_config(), ) - client: openai.AsyncOpenAI | None = None + + async def scenario_iter(): + for prompt in prompts: + yield {"prompt": prompt} + + await model.register(backend) + client = model.openai_client() try: - await model.register(backend) - client = model.openai_client() + + async def rollout_fn( + rollout_model: art.TrainableModel, + scenario: dict[str, str], + _config: None, + ) -> art.TrajectoryGroup: + await asyncio.sleep(0.2) + messages: art.Messages = [ + {"role": "user", "content": scenario["prompt"]} + ] + completion = await client.chat.completions.create( + messages=messages, + model=rollout_model.get_inference_name(), + max_tokens=10, + timeout=60, + temperature=1, + n=2, + logprobs=True, + top_logprobs=0, + ) + return art.TrajectoryGroup( + [ + art.Trajectory( + messages_and_choices=[*messages, choice], + reward=reward_for_answer(choice.message.content or ""), + ) + for choice in completion.choices + ] + ) + trainer = PipelineTrainer( model=model, backend=backend, @@ -180,5 +181,4 @@ async def scenario_iter(): assert f"{model.name}@0" in model_ids assert f"{model.name}@{latest_step}" in model_ids finally: - if client is not None: - await client.close() + await client.close() diff --git a/tests/unit/test_pipeline_trainer_local_backend.py b/tests/unit/test_pipeline_trainer_local_backend.py index 80e02e72f..e63fdb59a 100644 --- a/tests/unit/test_pipeline_trainer_local_backend.py +++ b/tests/unit/test_pipeline_trainer_local_backend.py @@ -13,14 +13,12 @@ from art.utils.output_dirs import get_model_dir -def _make_group( - rewards: list[float], *, initial_policy_version: int | None = 0 -) -> TrajectoryGroup: +def _make_group(rewards: list[float]) -> TrajectoryGroup: return TrajectoryGroup( [ Trajectory( reward=reward, - initial_policy_version=initial_policy_version, + initial_policy_version=0, messages_and_choices=[ {"role": "user", "content": f"prompt-{idx}"}, {"role": "assistant", "content": f"answer-{idx}"}, @@ -35,7 +33,6 @@ def _make_trainer( *, model: TrainableModel, backend: object, - tmp_path: Path, **kwargs: Any, ) -> PipelineTrainer: return PipelineTrainer( @@ -54,9 +51,7 @@ def _make_trainer( @pytest.mark.asyncio -async def test_pipeline_trainer_preserves_default_backend_train_kwargs( - tmp_path: Path, -) -> None: +async def test_pipeline_trainer_preserves_backend_train_kwargs(tmp_path: Path) -> None: model = TrainableModel( name="pipeline-default-backend-kwargs", project="pipeline-tests", @@ -71,7 +66,6 @@ async def test_pipeline_trainer_preserves_default_backend_train_kwargs( trainer = _make_trainer( model=model, backend=backend, - tmp_path=tmp_path, learning_rate=2e-5, loss_fn="cispo", loss_fn_config=loss_fn_config, @@ -84,7 +78,6 @@ async def test_pipeline_trainer_preserves_default_backend_train_kwargs( await trainer._training_stage() - assert backend.train.await_count == 1 assert backend.train.await_args.kwargs == { "learning_rate": 2e-5, "loss_fn": "cispo", @@ -96,7 +89,7 @@ async def test_pipeline_trainer_preserves_default_backend_train_kwargs( @pytest.mark.asyncio -async def test_pipeline_trainer_translates_local_backend_kwargs( +async def test_pipeline_trainer_uses_same_train_kwargs_for_local_backend( tmp_path: Path, ) -> None: model = TrainableModel( @@ -115,7 +108,6 @@ async def test_pipeline_trainer_translates_local_backend_kwargs( trainer = _make_trainer( model=model, backend=backend, - tmp_path=tmp_path, learning_rate=3e-5, loss_fn="ppo", ) @@ -125,37 +117,56 @@ async def test_pipeline_trainer_translates_local_backend_kwargs( await trainer._training_stage() - assert backend.train.await_count == 1 # type: ignore[attr-defined] assert backend.train.await_args.kwargs == { # type: ignore[attr-defined] "learning_rate": 3e-5, - "ppo": True, + "loss_fn": "ppo", + "loss_fn_config": None, + "normalize_advantages": True, "save_checkpoint": False, + "adam_params": None, } @pytest.mark.asyncio -async def test_local_backend_close_awaits_async_service_cleanup( - tmp_path: Path, -) -> None: +async def test_local_backend_train_translates_loss_fn(tmp_path: Path) -> None: + model = TrainableModel( + name="local-backend-train-translation", + project="pipeline-tests", + base_model="test-model", + base_path=str(tmp_path), + ) backend = LocalBackend(path=str(tmp_path)) - calls: list[str] = [] - - class FakeService: - async def aclose(self) -> None: - calls.append("aclose") - - service = FakeService() - backend._services["test-service"] = service # type: ignore[assignment] + seen: dict[str, Any] = {} + + async def fake_train_model( + _model: TrainableModel, + _groups: list[TrajectoryGroup], + config: Any, + dev_config: dict[str, Any], + verbose: bool = False, + ): + seen["config"] = config + seen["dev_config"] = dev_config + seen["verbose"] = verbose + yield {} + + backend._train_model = fake_train_model # type: ignore[method-assign] + backend._get_step = AsyncMock(return_value=1) # type: ignore[method-assign] + with patch.object(model, "_get_wandb_run", return_value=None): + result = await backend.train( + model, + [_make_group([1.0])], + loss_fn="ppo", + save_checkpoint=False, + ) - with patch("art.local.backend.close_proxy") as close_proxy: - await backend.close() - - assert calls == ["aclose"] - close_proxy.assert_called_once_with(service) + assert result.step == 1 + assert seen["config"].learning_rate == 5e-6 + assert seen["dev_config"]["ppo"] is True @pytest.mark.asyncio -async def test_local_backend_async_context_manager_awaits_async_service_cleanup( +async def test_local_backend_async_context_manager_awaits_async_cleanup( tmp_path: Path, ) -> None: backend = LocalBackend(path=str(tmp_path)) @@ -166,7 +177,7 @@ async def aclose(self) -> None: calls.append("aclose") service = FakeService() - backend._services["test-service"] = service # type: ignore[assignment] + backend._services["test-service"] = cast(Any, service) with patch("art.local.backend.close_proxy") as close_proxy: async with backend: @@ -176,54 +187,6 @@ async def aclose(self) -> None: close_proxy.assert_called_once_with(service) -@pytest.mark.asyncio -async def test_local_backend_sync_context_manager_warns_in_running_loop( - tmp_path: Path, -) -> None: - backend = LocalBackend(path=str(tmp_path)) - calls: list[str] = [] - - class FakeService: - def close(self) -> None: - calls.append("close") - - service = FakeService() - backend._services["test-service"] = service # type: ignore[assignment] - - with patch("art.local.backend.close_proxy") as close_proxy: - with pytest.warns(RuntimeWarning, match="async with LocalBackend"): - with backend: - pass - - assert calls == ["close"] - close_proxy.assert_called_once_with(service) - - -def test_local_backend_sync_context_manager_warns_and_uses_sync_close( - tmp_path: Path, -) -> None: - backend = LocalBackend(path=str(tmp_path)) - calls: list[str] = [] - - class FakeService: - def close(self) -> None: - calls.append("close") - - async def aclose(self) -> None: - calls.append("aclose") - - service = FakeService() - backend._services["test-service"] = service # type: ignore[assignment] - - with patch("art.local.backend.close_proxy") as close_proxy: - with pytest.warns(RuntimeWarning, match="best-effort sync shutdown"): - with backend: - pass - - assert calls == ["close"] - close_proxy.assert_called_once_with(service) - - @pytest.mark.parametrize( ("trainer_kwargs", "match"), [ @@ -248,12 +211,11 @@ def test_pipeline_trainer_rejects_unsupported_local_backend_settings( inference_gpu_ids=[1], ), ) - backend = LocalBackend(path=str(tmp_path)) + with pytest.raises(ValueError, match=match): _make_trainer( model=model, - backend=backend, - tmp_path=tmp_path, + backend=LocalBackend(path=str(tmp_path)), **trainer_kwargs, ) @@ -265,12 +227,11 @@ def test_pipeline_trainer_rejects_shared_local_backend(tmp_path: Path) -> None: base_model="test-model", base_path=str(tmp_path), ) - backend = LocalBackend(path=str(tmp_path)) with pytest.raises( ValueError, match="only supports LocalBackend in dedicated mode" ): - _make_trainer(model=model, backend=backend, tmp_path=tmp_path) + _make_trainer(model=model, backend=LocalBackend(path=str(tmp_path))) def test_local_backend_inference_name_prefers_served_step_in_dedicated_mode( From dbefea60c585b1d07f7d6a3f88dde4ce41035d82 Mon Sep 17 00:00:00 2001 From: Vivek Kalyan Date: Wed, 18 Mar 2026 15:12:48 -0700 Subject: [PATCH 05/13] refactor: Narrow LocalBackend train types --- src/art/local/backend.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/art/local/backend.py b/src/art/local/backend.py index 771b81632..c876bd319 100644 --- a/src/art/local/backend.py +++ b/src/art/local/backend.py @@ -523,7 +523,7 @@ async def train( # type: ignore[override] *, # Core training parameters learning_rate: float = 5e-6, - loss_fn: Literal["cispo", "ppo", "importance_sampling", "dro"] = "cispo", + loss_fn: Literal["cispo", "ppo"] = "cispo", loss_fn_config: dict | None = None, normalize_advantages: bool = True, adam_params: object | None = None, @@ -584,7 +584,7 @@ async def train( # type: ignore[override] kl_ref_adapter_path: Direct filesystem path to a LoRA adapter checkpoint to use as the KL reference. Alternative to kl_penalty_reference_step. - epsilon: Clip epsilon for importance sampling. Defaults based on ppo. + epsilon: Clip epsilon for importance sampling. Defaults based on loss_fn. epsilon_high: Asymmetric upper clip bound. Defaults to epsilon. advantage_balance: Balance between negative and positive advantages in range [-1.0, 1.0]. Defaults to 0.0 (balanced). From 05f39ef6f4cfb000c9099be12aaddd8ce93e9769 Mon Sep 17 00:00:00 2001 From: Vivek Kalyan Date: Wed, 18 Mar 2026 15:52:39 -0700 Subject: [PATCH 06/13] test: Add max batch size regression coverage --- tests/unit/test_pipeline_trainer_batching.py | 67 ++++++++++++++++++++ 1 file changed, 67 insertions(+) create mode 100644 tests/unit/test_pipeline_trainer_batching.py diff --git a/tests/unit/test_pipeline_trainer_batching.py b/tests/unit/test_pipeline_trainer_batching.py new file mode 100644 index 000000000..0ab412e8f --- /dev/null +++ b/tests/unit/test_pipeline_trainer_batching.py @@ -0,0 +1,67 @@ +import asyncio +from pathlib import Path +from unittest.mock import MagicMock + +import pytest + +from art import TrainableModel, Trajectory, TrajectoryGroup +from art.pipeline_trainer.trainer import PipelineTrainer + + +def _make_group() -> TrajectoryGroup: + return TrajectoryGroup( + [ + Trajectory( + reward=reward, + initial_policy_version=0, + messages_and_choices=[ + {"role": "user", "content": f"prompt-{idx}"}, + {"role": "assistant", "content": f"answer-{idx}"}, + ], + ) + for idx, reward in enumerate([0.0, 1.0]) + ] + ) + + +@pytest.mark.asyncio +async def test_collect_batch_respects_max_batch_size(tmp_path: Path) -> None: + model = TrainableModel( + name="pipeline-max-batch-size-test", + project="pipeline-tests", + base_model="test-model", + base_path=str(tmp_path), + ) + trainer = PipelineTrainer( + model=model, + backend=MagicMock(), # type: ignore[arg-type] + rollout_fn=lambda *_args, **_kwargs: asyncio.sleep(0), + scenarios=[], + config={}, + num_rollout_workers=1, + min_batch_size=1, + max_batch_size=2, + max_steps=1, + eval_fn=None, + ) + trainer._output_queue = asyncio.Queue() + + first = _make_group() + second = _make_group() + third = _make_group() + await trainer._output_queue.put(first) + await trainer._output_queue.put(second) + await trainer._output_queue.put(third) + await trainer._output_queue.put(None) + + batch, discarded, saw_sentinel = await trainer._collect_batch(current_step=0) + + assert batch == [first, second] + assert discarded == 0 + assert not saw_sentinel + + batch, discarded, saw_sentinel = await trainer._collect_batch(current_step=0) + + assert batch == [third] + assert discarded == 0 + assert saw_sentinel From e7480711bde34b530bf761b6ab3e46377092e9c6 Mon Sep 17 00:00:00 2001 From: Vivek Kalyan Date: Wed, 18 Mar 2026 15:52:47 -0700 Subject: [PATCH 07/13] fix: Respect max batch size in PipelineTrainer --- src/art/pipeline_trainer/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/art/pipeline_trainer/trainer.py b/src/art/pipeline_trainer/trainer.py index 77966857a..302cbe78c 100644 --- a/src/art/pipeline_trainer/trainer.py +++ b/src/art/pipeline_trainer/trainer.py @@ -552,7 +552,7 @@ async def _collect_batch( continue batch.append(item) - while not saw_sentinel: + while not saw_sentinel and len(batch) < self.max_batch_size: try: item = self._output_queue.get_nowait() except asyncio.QueueEmpty: From 9775490d31583e7c8e0a6e7f18e6b45d47b393a1 Mon Sep 17 00:00:00 2001 From: Vivek Kalyan Date: Thu, 19 Mar 2026 16:50:46 -0700 Subject: [PATCH 08/13] feat: Add sampled KL support to pipeline backends --- src/art/dev/train.py | 1 + src/art/local/backend.py | 12 +- src/art/loss.py | 9 +- src/art/pipeline_trainer/trainer.py | 13 +++ src/art/test/test_kl_advantage.py | 49 +++++++- src/art/tinker_native/backend.py | 108 ++++++++++++++++++ src/art/types.py | 1 + .../test_pipeline_trainer_local_backend.py | 77 +++++++++++++ tests/unit/test_tinker_native_kl.py | 97 ++++++++++++++++ 9 files changed, 364 insertions(+), 3 deletions(-) create mode 100644 tests/unit/test_tinker_native_kl.py diff --git a/src/art/dev/train.py b/src/art/dev/train.py index b0e232c59..5da3e1ab9 100644 --- a/src/art/dev/train.py +++ b/src/art/dev/train.py @@ -18,6 +18,7 @@ class TrainConfig(TypedDict, total=False): ] kimi_k2_tau: float | None kl_penalty_coef: float + kl_penalty_source: Literal["current_learner", "sample"] kl_ref_adapter_path: str | None logprob_calculation_chunk_size: int mask_prob_ratio: bool diff --git a/src/art/local/backend.py b/src/art/local/backend.py index c876bd319..723383adc 100644 --- a/src/art/local/backend.py +++ b/src/art/local/backend.py @@ -531,6 +531,7 @@ async def train( # type: ignore[override] kl_penalty_coef: float = 0.0, kl_penalty_reference_step: int | None = None, kl_ref_adapter_path: str | None = None, + kl_penalty_source: Literal["current_learner", "sample"] = "current_learner", epsilon: float | None = None, epsilon_high: float | None = None, # Advantage computation @@ -584,6 +585,11 @@ async def train( # type: ignore[override] kl_ref_adapter_path: Direct filesystem path to a LoRA adapter checkpoint to use as the KL reference. Alternative to kl_penalty_reference_step. + kl_penalty_source: Which policy's logprobs to compare against the + reference when building the centered KL penalty. Use + "current_learner" to match the original ART implementation, or + "sample" to shape from the rollout policy logprobs, which is + usually better for async/off-policy workloads. epsilon: Clip epsilon for importance sampling. Defaults based on loss_fn. epsilon_high: Asymmetric upper clip bound. Defaults to epsilon. advantage_balance: Balance between negative and positive advantages @@ -635,16 +641,20 @@ async def train( # type: ignore[override] raise ValueError("LocalBackend requires normalize_advantages=True.") if adam_params is not None: raise ValueError("LocalBackend requires adam_params=None.") + assert kl_penalty_source in {"current_learner", "sample"} # Build config objects from explicit kwargs config = TrainConfig( - learning_rate=learning_rate, kl_penalty_coef=kl_penalty_coef + learning_rate=learning_rate, + kl_penalty_coef=kl_penalty_coef, + kl_penalty_source=kl_penalty_source, ) dev_config: dev.TrainConfig = { "advantage_balance": advantage_balance, "allow_training_without_logprobs": allow_training_without_logprobs, "importance_sampling_level": importance_sampling_level, "kl_penalty_coef": kl_penalty_coef, + "kl_penalty_source": kl_penalty_source, "mask_prob_ratio": mask_prob_ratio, "plot_tensors": plot_tensors, "ppo": loss_fn == "ppo", diff --git a/src/art/loss.py b/src/art/loss.py index 5a73d7b72..59cfa46a7 100644 --- a/src/art/loss.py +++ b/src/art/loss.py @@ -95,7 +95,14 @@ def loss_fn( kl_policy_ref: torch.Tensor | None = None kl_penalty_coef = experimental_config.get("kl_penalty_coef", 0.0) if kl_penalty_coef > 0 and ref_logprobs is not None: - kl_per_token = (new_logprobs - ref_logprobs).detach() * assistant_mask + match experimental_config.get("kl_penalty_source", "current_learner"): + case "sample": + kl_source_logprobs = old_logprobs.detach() + case "current_learner": + kl_source_logprobs = new_logprobs.detach() + case other: + raise AssertionError(other) + kl_per_token = (kl_source_logprobs - ref_logprobs).detach() * assistant_mask avg_kl = kl_per_token.sum() / (assistant_mask.sum() + 1e-6) kl_penalty = kl_penalty_coef * (avg_kl - kl_per_token) * assistant_mask advantages = advantages + kl_penalty diff --git a/src/art/pipeline_trainer/trainer.py b/src/art/pipeline_trainer/trainer.py index 302cbe78c..a50e6d57a 100644 --- a/src/art/pipeline_trainer/trainer.py +++ b/src/art/pipeline_trainer/trainer.py @@ -78,6 +78,8 @@ def __init__( loss_fn_config: dict | None = None, normalize_advantages: bool = True, adam_params: object | None = None, + kl_penalty_coef: float = 0.0, + kl_penalty_reference_step: int | None = None, max_steps: int | None = None, # Discard handling discard_queue_multiplier: int = 100, @@ -129,6 +131,8 @@ def __init__( self.loss_fn_config = loss_fn_config self.normalize_advantages = normalize_advantages self.adam_params = adam_params + self.kl_penalty_coef = kl_penalty_coef + self.kl_penalty_reference_step = kl_penalty_reference_step self.max_steps = max_steps self._status_log_interval_seconds = log_interval_seconds self.eval_every_n_steps = eval_every_n_steps @@ -452,6 +456,14 @@ async def _training_stage(self) -> None: if os.getenv("ART_TRAIN_STEP_LOG"): print(f"[train] step {expected_step} starting (batch={len(batch)})") try: + kl_train_kwargs: dict[str, object] = {} + if self.kl_penalty_coef > 0.0: + kl_train_kwargs["kl_penalty_coef"] = self.kl_penalty_coef + kl_train_kwargs["kl_penalty_source"] = "sample" + if self.kl_penalty_reference_step is not None: + kl_train_kwargs["kl_penalty_reference_step"] = ( + self.kl_penalty_reference_step + ) result = await self.backend.train( self.model, batch, @@ -461,6 +473,7 @@ async def _training_stage(self) -> None: normalize_advantages=self.normalize_advantages, save_checkpoint=should_checkpoint, adam_params=self.adam_params, + **kl_train_kwargs, ) except Exception: self._status.note_training_end() diff --git a/src/art/test/test_kl_advantage.py b/src/art/test/test_kl_advantage.py index d944efc62..82c0f2a25 100644 --- a/src/art/test/test_kl_advantage.py +++ b/src/art/test/test_kl_advantage.py @@ -2,7 +2,7 @@ import torch -from art.loss import Loss, loss_fn +from art.loss import loss_fn, shift_tensor def _make_inputs( @@ -114,3 +114,50 @@ def test_kl_advantage_does_not_affect_when_no_ref(): loss = loss_fn(inputs, new_logprobs, None, None, {"kl_penalty_coef": 0.5}) assert loss.kl_policy_ref is None + + +def test_kl_advantage_can_use_sample_logprobs() -> None: + """Sample-source KL should use stored rollout logprobs rather than learner logprobs.""" + inputs = _make_inputs(seq_len=8) + inputs["logprobs"] = torch.tensor( + [[0.0, -0.2, -0.4, -0.6, -0.8, -1.0, -1.2, -1.4]], dtype=torch.float32 + ) + new_logprobs = torch.tensor( + [[0.0, -1.2, -1.1, -1.0, -0.9, -0.8, -0.7, -0.6]], dtype=torch.float32 + ) + ref_logprobs = torch.full((1, 8), -0.5) + assistant_mask = shift_tensor(inputs["assistant_mask"], False).to( + new_logprobs.dtype + ) + sampled_logprobs = torch.where( + torch.isnan(shift_tensor(inputs["logprobs"], float("nan"))), + new_logprobs.detach(), + shift_tensor(inputs["logprobs"], float("nan")), + ) + expected_sample_kl = ((sampled_logprobs - ref_logprobs) * assistant_mask).sum() / ( + assistant_mask.sum() + 1e-6 + ) + expected_current_kl = ((new_logprobs - ref_logprobs) * assistant_mask).sum() / ( + assistant_mask.sum() + 1e-6 + ) + + sample_loss = loss_fn( + inputs, + new_logprobs, + ref_logprobs, + None, + {"kl_penalty_coef": 0.5, "kl_penalty_source": "sample"}, + ) + learner_loss = loss_fn( + inputs, + new_logprobs, + ref_logprobs, + None, + {"kl_penalty_coef": 0.5, "kl_penalty_source": "current_learner"}, + ) + + assert sample_loss.kl_policy_ref is not None + assert learner_loss.kl_policy_ref is not None + assert torch.isclose(sample_loss.kl_policy_ref, expected_sample_kl) + assert torch.isclose(learner_loss.kl_policy_ref, expected_current_kl) + assert not torch.isclose(sample_loss.kl_policy_ref, learner_loss.kl_policy_ref) diff --git a/src/art/tinker_native/backend.py b/src/art/tinker_native/backend.py index c1687bf7f..fe64e425a 100644 --- a/src/art/tinker_native/backend.py +++ b/src/art/tinker_native/backend.py @@ -24,6 +24,7 @@ from openai.types.chat.completion_create_params import CompletionCreateParams from openai.types.completion_usage import CompletionUsage import tinker +import torch import uvicorn from art.tinker.cookbook_v import renderers, tokenizer_utils @@ -82,6 +83,76 @@ def _canonicalize_upstream_metric_key(metric: str) -> str: return _UPSTREAM_TRAIN_METRIC_KEYS.get(metric, metric) +async def _apply_kl_penalty( + datums: list[tinker.Datum], + reference_sampling_client: tinker.SamplingClient, + kl_penalty_coef: float, +) -> dict[str, float]: + assert datums + assert kl_penalty_coef > 0.0 + + full_sequences: list[tinker.ModelInput] = [] + sampled_logprobs_by_datum: list[torch.Tensor] = [] + masks_by_datum: list[torch.Tensor] = [] + advantages_by_datum: list[torch.Tensor] = [] + for datum in datums: + target_tokens = datum.loss_fn_inputs["target_tokens"].to_torch() + assert target_tokens.numel() > 0 + full_sequences.append( + datum.model_input.append_int(int(target_tokens[-1].item())) + ) + sampled_logprobs_by_datum.append(datum.loss_fn_inputs["logprobs"].to_torch()) + masks_by_datum.append(datum.loss_fn_inputs["mask"].to_torch().float()) + advantages_by_datum.append(datum.loss_fn_inputs["advantages"].to_torch()) + + reference_logprobs_by_datum = await asyncio.gather( + *[ + reference_sampling_client.compute_logprobs_async(full_sequence) + for full_sequence in full_sequences + ] + ) + + logprob_diffs_by_datum: list[torch.Tensor] = [] + for reference_logprobs, sampled_logprobs, mask in zip( + reference_logprobs_by_datum, + sampled_logprobs_by_datum, + masks_by_datum, + strict=True, + ): + reference_values = reference_logprobs[1:] + assert len(reference_values) == sampled_logprobs.numel() + assert all(value is not None for value in reference_values) + reference_logprobs_tensor = torch.tensor( + reference_values, + dtype=sampled_logprobs.dtype, + ) + logprob_diffs_by_datum.append( + (sampled_logprobs - reference_logprobs_tensor) * mask + ) + + total_tokens = torch.stack([mask.sum() for mask in masks_by_datum]).sum() + assert total_tokens.item() > 0 + avg_logprob_diff = ( + torch.stack( + [logprob_diff.sum() for logprob_diff in logprob_diffs_by_datum] + ).sum() + / total_tokens + ) + + for datum, advantages, mask, logprob_diff in zip( + datums, + advantages_by_datum, + masks_by_datum, + logprob_diffs_by_datum, + strict=True, + ): + datum.loss_fn_inputs["advantages"] = tinker.TensorData.from_torch( + advantages + kl_penalty_coef * (avg_logprob_diff - logprob_diff) * mask + ) + + return {"loss/kl_policy_ref": float(avg_logprob_diff)} + + @dataclass class ModelState: service_client: tinker.ServiceClient @@ -239,6 +310,9 @@ async def train( # type: ignore[override] save_checkpoint: bool = False, loss_fn_config: dict | None = None, adam_params: tinker.AdamParams | None = None, + kl_penalty_coef: float = 0.0, + kl_penalty_reference_step: int | None = None, + kl_penalty_source: Literal["sample"] = "sample", ) -> TrainResult: state = self._model_state[model.name] groups_list = list(trajectory_groups) @@ -259,6 +333,10 @@ async def train( # type: ignore[override] "data/step_num_datums": float(len(datums)), } + assert kl_penalty_source == "sample", ( + "TinkerNativeBackend only supports kl_penalty_source='sample'." + ) + if not datums: return TrainResult(step=state.current_step, metrics=metrics) @@ -273,6 +351,23 @@ async def train( # type: ignore[override] ) trainer_started = time.monotonic() + if kl_penalty_coef > 0: + reference_sampling_client = await self._get_kl_reference_sampling_client( + state, + model.base_model, + kl_penalty_reference_step, + ) + metrics.update( + await self._tinker_sample_call( + "apply_kl_penalty", + _apply_kl_penalty( + datums, + reference_sampling_client, + kl_penalty_coef, + ), + ) + ) + if adam_params is None: adam_params = tinker.AdamParams( learning_rate=learning_rate, @@ -697,6 +792,19 @@ async def _get_sampler_client( state.sampler_clients[actual_step] = sampler_client return sampler_client + async def _get_kl_reference_sampling_client( + self, + state: ModelState, + base_model: str, + step: int | None, + ) -> tinker.SamplingClient: + if step is not None: + return await self._get_sampler_client(state, step) + return await self._tinker_sample_call( + "create_sampling_client_async", + state.service_client.create_sampling_client_async(base_model=base_model), + ) + def _normalize_messages(self, messages: Iterable[Any]) -> list[dict[str, Any]]: normalized: list[dict[str, Any]] = [] for message in messages: diff --git a/src/art/types.py b/src/art/types.py index 088041add..317fc156a 100644 --- a/src/art/types.py +++ b/src/art/types.py @@ -17,6 +17,7 @@ class TrainConfig(pydantic.BaseModel): learning_rate: float = 5e-6 kl_penalty_coef: float = 0.0 + kl_penalty_source: Literal["current_learner", "sample"] = "current_learner" class TrainSFTConfig(pydantic.BaseModel): diff --git a/tests/unit/test_pipeline_trainer_local_backend.py b/tests/unit/test_pipeline_trainer_local_backend.py index e63fdb59a..7219e55ad 100644 --- a/tests/unit/test_pipeline_trainer_local_backend.py +++ b/tests/unit/test_pipeline_trainer_local_backend.py @@ -88,6 +88,44 @@ async def test_pipeline_trainer_preserves_backend_train_kwargs(tmp_path: Path) - } +@pytest.mark.asyncio +async def test_pipeline_trainer_forwards_kl_kwargs_for_generic_backend( + tmp_path: Path, +) -> None: + model = TrainableModel( + name="pipeline-generic-backend-kl-kwargs", + project="pipeline-tests", + base_model="test-model", + base_path=str(tmp_path), + ) + backend = MagicMock() + backend.train = AsyncMock(return_value=SimpleNamespace(step=1, metrics={})) + + trainer = _make_trainer( + model=model, + backend=backend, + kl_penalty_coef=0.25, + kl_penalty_reference_step=7, + ) + trainer._output_queue = asyncio.Queue() + await trainer._output_queue.put(_make_group([0.0, 1.0])) + await trainer._output_queue.put(None) + + await trainer._training_stage() + + assert backend.train.await_args.kwargs == { + "learning_rate": 1e-5, + "loss_fn": "cispo", + "loss_fn_config": None, + "normalize_advantages": True, + "save_checkpoint": False, + "adam_params": None, + "kl_penalty_coef": 0.25, + "kl_penalty_reference_step": 7, + "kl_penalty_source": "sample", + } + + @pytest.mark.asyncio async def test_pipeline_trainer_uses_same_train_kwargs_for_local_backend( tmp_path: Path, @@ -165,6 +203,45 @@ async def fake_train_model( assert seen["dev_config"]["ppo"] is True +@pytest.mark.asyncio +async def test_local_backend_train_passes_kl_penalty_source(tmp_path: Path) -> None: + model = TrainableModel( + name="local-backend-kl-source", + project="pipeline-tests", + base_model="test-model", + base_path=str(tmp_path), + ) + backend = LocalBackend(path=str(tmp_path)) + seen: dict[str, Any] = {} + + async def fake_train_model( + _model: TrainableModel, + _groups: list[TrajectoryGroup], + config: Any, + dev_config: dict[str, Any], + verbose: bool = False, + ): + seen["config"] = config + seen["dev_config"] = dev_config + seen["verbose"] = verbose + yield {} + + backend._train_model = fake_train_model # type: ignore[method-assign] + backend._get_step = AsyncMock(return_value=1) # type: ignore[method-assign] + with patch.object(model, "_get_wandb_run", return_value=None): + result = await backend.train( + model, + [_make_group([1.0])], + kl_penalty_coef=0.25, + kl_penalty_source="sample", + save_checkpoint=False, + ) + + assert result.step == 1 + assert seen["config"].kl_penalty_source == "sample" + assert seen["dev_config"]["kl_penalty_source"] == "sample" + + @pytest.mark.asyncio async def test_local_backend_async_context_manager_awaits_async_cleanup( tmp_path: Path, diff --git a/tests/unit/test_tinker_native_kl.py b/tests/unit/test_tinker_native_kl.py new file mode 100644 index 000000000..93991e24e --- /dev/null +++ b/tests/unit/test_tinker_native_kl.py @@ -0,0 +1,97 @@ +from typing import Any, cast + +import pytest +import tinker + +from art import TrainableModel +from art.tinker_native.backend import ( + ModelState, + TinkerNativeBackend, + _apply_kl_penalty, +) +from art.tinker_native.data import build_datum + + +class FakeSamplingClient: + def __init__(self, responses: dict[tuple[int, ...], list[float | None]]) -> None: + self._responses = responses + + async def compute_logprobs_async( + self, prompt: tinker.ModelInput + ) -> list[float | None]: + return self._responses[tuple(prompt.to_ints())] + + +@pytest.mark.asyncio +async def test_incorporate_kl_penalty_rewrites_advantages_in_place() -> None: + datum_a = build_datum( + prompt_tokens=[101, 102], + completion_tokens=[201, 202], + logprobs=[-0.4, -0.8], + advantage=1.0, + ) + datum_b = build_datum( + prompt_tokens=[301, 302], + completion_tokens=[401], + logprobs=[-0.2], + advantage=2.0, + ) + assert datum_a is not None + assert datum_b is not None + + sampling_client = FakeSamplingClient( + { + (101, 102, 201, 202): [None, -9.0, -0.1, -0.5], + (301, 302, 401): [None, -7.0, -0.05], + } + ) + + metrics = await _apply_kl_penalty( + [datum_a, datum_b], + sampling_client, # type: ignore[arg-type] + kl_penalty_coef=2.0, + ) + + assert metrics == {"loss/kl_policy_ref": pytest.approx(-0.25)} + assert datum_a.loss_fn_inputs["advantages"].tolist() == pytest.approx( + [0.0, 1.1, 1.1] + ) + assert datum_b.loss_fn_inputs["advantages"].tolist() == pytest.approx([0.0, 1.8]) + + +@pytest.mark.asyncio +async def test_tinker_native_backend_rejects_current_learner_kl_source( + tmp_path, +) -> None: + backend = TinkerNativeBackend(tinker_api_key="test-key", path=str(tmp_path)) + model = TrainableModel( + name="tinker-native-kl-source", + project="pipeline-tests", + base_model="test-model", + base_path=str(tmp_path), + ) + backend._model_state[model.name] = ModelState( + service_client=cast(Any, object()), + rest_client=cast(Any, object()), + training_client=cast(Any, object()), + sampler_clients={}, + sampler_checkpoint_paths={}, + training_checkpoint_paths={}, + current_step=0, + renderer=cast(Any, object()), + tokenizer=cast(Any, object()), + output_dir=str(tmp_path), + tinker_run_ids=[], + model_name=model.name, + ) + + with pytest.raises( + AssertionError, + match="only supports kl_penalty_source='sample'", + ): + await cast(Any, backend).train( + model, + [], + kl_penalty_coef=0.25, + kl_penalty_source="current_learner", + ) From e113b3bd0683d7f6f8bb2e46f1b8724304738558 Mon Sep 17 00:00:00 2001 From: Vivek Kalyan Date: Thu, 19 Mar 2026 16:55:33 -0700 Subject: [PATCH 09/13] refactor: Validate TinkerNative KL source before state lookup --- src/art/tinker_native/backend.py | 8 ++++---- tests/unit/test_tinker_native_kl.py | 30 +++++------------------------ 2 files changed, 9 insertions(+), 29 deletions(-) diff --git a/src/art/tinker_native/backend.py b/src/art/tinker_native/backend.py index fe64e425a..949ad7044 100644 --- a/src/art/tinker_native/backend.py +++ b/src/art/tinker_native/backend.py @@ -314,6 +314,10 @@ async def train( # type: ignore[override] kl_penalty_reference_step: int | None = None, kl_penalty_source: Literal["sample"] = "sample", ) -> TrainResult: + assert kl_penalty_source == "sample", ( + "TinkerNativeBackend only supports kl_penalty_source='sample'." + ) + state = self._model_state[model.name] groups_list = list(trajectory_groups) summary = summarize_trajectory_groups(groups_list) @@ -333,10 +337,6 @@ async def train( # type: ignore[override] "data/step_num_datums": float(len(datums)), } - assert kl_penalty_source == "sample", ( - "TinkerNativeBackend only supports kl_penalty_source='sample'." - ) - if not datums: return TrainResult(step=state.current_step, metrics=metrics) diff --git a/tests/unit/test_tinker_native_kl.py b/tests/unit/test_tinker_native_kl.py index 93991e24e..a2d16d01f 100644 --- a/tests/unit/test_tinker_native_kl.py +++ b/tests/unit/test_tinker_native_kl.py @@ -1,18 +1,12 @@ -from typing import Any, cast - import pytest import tinker from art import TrainableModel -from art.tinker_native.backend import ( - ModelState, - TinkerNativeBackend, - _apply_kl_penalty, -) +from art.tinker_native.backend import TinkerNativeBackend, _apply_kl_penalty from art.tinker_native.data import build_datum -class FakeSamplingClient: +class FakeSamplingClient(tinker.SamplingClient): def __init__(self, responses: dict[tuple[int, ...], list[float | None]]) -> None: self._responses = responses @@ -48,7 +42,7 @@ async def test_incorporate_kl_penalty_rewrites_advantages_in_place() -> None: metrics = await _apply_kl_penalty( [datum_a, datum_b], - sampling_client, # type: ignore[arg-type] + sampling_client, kl_penalty_coef=2.0, ) @@ -70,28 +64,14 @@ async def test_tinker_native_backend_rejects_current_learner_kl_source( base_model="test-model", base_path=str(tmp_path), ) - backend._model_state[model.name] = ModelState( - service_client=cast(Any, object()), - rest_client=cast(Any, object()), - training_client=cast(Any, object()), - sampler_clients={}, - sampler_checkpoint_paths={}, - training_checkpoint_paths={}, - current_step=0, - renderer=cast(Any, object()), - tokenizer=cast(Any, object()), - output_dir=str(tmp_path), - tinker_run_ids=[], - model_name=model.name, - ) with pytest.raises( AssertionError, match="only supports kl_penalty_source='sample'", ): - await cast(Any, backend).train( + await backend.train( model, [], kl_penalty_coef=0.25, - kl_penalty_source="current_learner", + kl_penalty_source="current_learner", # ty:ignore[invalid-argument-type] ) From 79bf8d539931b9fd1a251c201a2f00152c162941 Mon Sep 17 00:00:00 2001 From: Vivek Kalyan Date: Thu, 19 Mar 2026 17:22:16 -0700 Subject: [PATCH 10/13] test: Add PipelineTrainer KL smoke coverage --- .../test_pipeline_localbackend_dedicated.py | 23 +++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/tests/integration/test_pipeline_localbackend_dedicated.py b/tests/integration/test_pipeline_localbackend_dedicated.py index d6d04bc7b..11fab51d5 100644 --- a/tests/integration/test_pipeline_localbackend_dedicated.py +++ b/tests/integration/test_pipeline_localbackend_dedicated.py @@ -1,7 +1,10 @@ """Dedicated LocalBackend smoke test for PipelineTrainer.""" import asyncio +import json +import math import os +from pathlib import Path import tempfile import uuid @@ -163,6 +166,8 @@ async def rollout_fn( min_batch_size=1, max_batch_size=1, max_steps=2, + kl_penalty_coef=0.25, + kl_penalty_reference_step=0, loss_fn="cispo", eval_fn=None, ) @@ -180,5 +185,23 @@ async def rollout_fn( model_ids = [m.id async for m in client.models.list()] assert f"{model.name}@0" in model_ids assert f"{model.name}@{latest_step}" in model_ids + + history_path = ( + Path(tmpdir) + / model.project + / "models" + / model.name + / "history.jsonl" + ) + history_rows = [ + json.loads(line) for line in history_path.read_text().splitlines() + ] + kl_values = [ + row["loss/kl_policy_ref"] + for row in history_rows + if "loss/kl_policy_ref" in row + ] + assert kl_values + assert all(math.isfinite(value) for value in kl_values) finally: await client.close() From 5ef9eaba96b94e5a56734cf48d089a9b12757e52 Mon Sep 17 00:00:00 2001 From: Vivek Kalyan Date: Thu, 19 Mar 2026 17:22:31 -0700 Subject: [PATCH 11/13] fix: Preserve sampled KL metric in TinkerNativeBackend --- src/art/tinker_native/backend.py | 36 +++++--- .../integration/test_tinker_native_backend.py | 83 +++++++++++++++++++ 2 files changed, 106 insertions(+), 13 deletions(-) diff --git a/src/art/tinker_native/backend.py b/src/art/tinker_native/backend.py index 949ad7044..65f59ca6a 100644 --- a/src/art/tinker_native/backend.py +++ b/src/art/tinker_native/backend.py @@ -350,23 +350,23 @@ async def train( # type: ignore[override] train_tokens, pricing ) trainer_started = time.monotonic() + sampled_kl_policy_ref: float | None = None if kl_penalty_coef > 0: - reference_sampling_client = await self._get_kl_reference_sampling_client( - state, - model.base_model, - kl_penalty_reference_step, - ) - metrics.update( - await self._tinker_sample_call( - "apply_kl_penalty", - _apply_kl_penalty( - datums, - reference_sampling_client, - kl_penalty_coef, + kl_metrics = await self._tinker_sample_call( + "apply_kl_penalty", + _apply_kl_penalty( + datums, + await self._get_kl_reference_sampling_client( + state, + model.base_model, + kl_penalty_reference_step, ), - ) + kl_penalty_coef, + ), ) + sampled_kl_policy_ref = kl_metrics["loss/kl_policy_ref"] + metrics.update(kl_metrics) if adam_params is None: adam_params = tinker.AdamParams( @@ -405,6 +405,11 @@ def remove_mask(datum: tinker.Datum) -> tinker.Datum: if value is None: continue canonical_key = _canonicalize_upstream_metric_key(key) + if ( + sampled_kl_policy_ref is not None + and canonical_key == "loss/kl_policy_ref" + ): + continue if canonical_key: metrics[canonical_key] = float(value) if optim_output.metrics: @@ -412,6 +417,11 @@ def remove_mask(datum: tinker.Datum) -> tinker.Datum: if value is None: continue canonical_key = _canonicalize_upstream_metric_key(key) + if ( + sampled_kl_policy_ref is not None + and canonical_key == "loss/kl_policy_ref" + ): + continue if canonical_key: metrics[canonical_key] = float(value) diff --git a/tests/integration/test_tinker_native_backend.py b/tests/integration/test_tinker_native_backend.py index 09ff33c47..4e5c61a59 100644 --- a/tests/integration/test_tinker_native_backend.py +++ b/tests/integration/test_tinker_native_backend.py @@ -9,6 +9,8 @@ import art from art.tinker_native import TinkerNativeBackend +from art.tinker_native.backend import _apply_kl_penalty +from art.tinker_native.data import trajectory_groups_to_datums DEFAULT_BASE_MODEL = "Qwen/Qwen3-30B-A3B-Instruct-2507" @@ -37,6 +39,8 @@ async def simple_rollout( max_tokens=10, timeout=60, temperature=1, + logprobs=True, + top_logprobs=0, ) choice = chat_completion.choices[0] content = (choice.message.content or "").lower() @@ -115,6 +119,85 @@ async def make_group(prompt: str) -> art.TrajectoryGroup: await backend.close() +@pytest.mark.skipif( + "TINKER_API_KEY" not in os.environ, + reason="TINKER_API_KEY not set - skipping TinkerNativeBackend KL test", +) +async def test_tinker_native_backend_kl_identity_metric(): + model_name = f"test-tinker-native-kl-{uuid.uuid4().hex[:8]}" + with tempfile.TemporaryDirectory() as tmpdir: + backend = TinkerNativeBackend(path=tmpdir) + model = art.TrainableModel( + name=model_name, + project="integration-tests", + base_model=get_base_model(), + ) + try: + await model.register(backend) + + openai_client = model.openai_client() + current_step = await model.get_step() + model_name_step = model.get_inference_name(step=current_step) + prompts = ["Say yes", "Say no", "Say maybe"] + + async def make_group(prompt: str) -> art.TrajectoryGroup: + import asyncio + + trajectories = await asyncio.gather( + *[ + simple_rollout(openai_client, model_name_step, prompt) + for _ in range(2) + ] + ) + return art.TrajectoryGroup(trajectories) # type: ignore[attr-defined] + + train_groups = await art.gather_trajectory_groups( # type: ignore[attr-defined] + [make_group(prompt) for prompt in prompts] + ) + ensure_reward_variance(train_groups) + + state = backend._model_state[model.name] + datums = trajectory_groups_to_datums( + train_groups, + state.renderer, + state.tokenizer, + ) + assert datums + + reference_sampling_client = await backend._get_kl_reference_sampling_client( + state, + model.base_model, + current_step, + ) + expected_kl = ( + await _apply_kl_penalty( + trajectory_groups_to_datums( + train_groups, + state.renderer, + state.tokenizer, + ), + reference_sampling_client, + kl_penalty_coef=0.25, + ) + )["loss/kl_policy_ref"] + + result = await backend.train( + model, + train_groups, + learning_rate=1e-5, + kl_penalty_coef=0.25, + kl_penalty_reference_step=current_step, + ) + + assert result.metrics["loss/kl_policy_ref"] == pytest.approx( + expected_kl, + abs=0.05, + ) + assert result.metrics["loss/kl_policy_ref"] == pytest.approx(0.0, abs=0.05) + finally: + await backend.close() + + @pytest.mark.skipif( "TINKER_API_KEY" not in os.environ, reason="TINKER_API_KEY not set - skipping TinkerNativeBackend fork test", From 644f1eaa2341da8b18c1711cffb2efb3fa9cb29e Mon Sep 17 00:00:00 2001 From: Angky William Date: Thu, 19 Mar 2026 21:31:31 -0700 Subject: [PATCH 12/13] feat: Replace kl_penalty_reference_step with kl_penalty_step_lag in PipelineTrainer MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Rename parameter from `kl_penalty_reference_step` to `kl_penalty_step_lag` - `None` (default): uses step 0 as KL reference (anchor to initial model) - `>= 1`: uses `max(0, current_step - lag)` as reference (rolling anchor) - Add validation that kl_penalty_step_lag must be >= 1 if specified - Update existing tests and add new tests for lag computation 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- src/art/pipeline_trainer/trainer.py | 14 ++-- .../test_pipeline_localbackend_dedicated.py | 1 - .../test_pipeline_trainer_local_backend.py | 69 ++++++++++++++++++- 3 files changed, 76 insertions(+), 8 deletions(-) diff --git a/src/art/pipeline_trainer/trainer.py b/src/art/pipeline_trainer/trainer.py index a50e6d57a..2522cc58e 100644 --- a/src/art/pipeline_trainer/trainer.py +++ b/src/art/pipeline_trainer/trainer.py @@ -79,7 +79,7 @@ def __init__( normalize_advantages: bool = True, adam_params: object | None = None, kl_penalty_coef: float = 0.0, - kl_penalty_reference_step: int | None = None, + kl_penalty_step_lag: int | None = None, max_steps: int | None = None, # Discard handling discard_queue_multiplier: int = 100, @@ -114,6 +114,8 @@ def __init__( raise ValueError("log_interval_seconds must be > 0") if discard_queue_multiplier <= 0: raise ValueError("discard_queue_multiplier must be > 0") + if kl_penalty_step_lag is not None and kl_penalty_step_lag < 1: + raise ValueError("kl_penalty_step_lag must be >= 1") self.model = model self.backend = backend self.rollout_fn = rollout_fn @@ -132,7 +134,7 @@ def __init__( self.normalize_advantages = normalize_advantages self.adam_params = adam_params self.kl_penalty_coef = kl_penalty_coef - self.kl_penalty_reference_step = kl_penalty_reference_step + self.kl_penalty_step_lag = kl_penalty_step_lag self.max_steps = max_steps self._status_log_interval_seconds = log_interval_seconds self.eval_every_n_steps = eval_every_n_steps @@ -460,9 +462,11 @@ async def _training_stage(self) -> None: if self.kl_penalty_coef > 0.0: kl_train_kwargs["kl_penalty_coef"] = self.kl_penalty_coef kl_train_kwargs["kl_penalty_source"] = "sample" - if self.kl_penalty_reference_step is not None: - kl_train_kwargs["kl_penalty_reference_step"] = ( - self.kl_penalty_reference_step + if self.kl_penalty_step_lag is None: + kl_train_kwargs["kl_penalty_reference_step"] = 0 + else: + kl_train_kwargs["kl_penalty_reference_step"] = max( + 0, current_step - self.kl_penalty_step_lag ) result = await self.backend.train( self.model, diff --git a/tests/integration/test_pipeline_localbackend_dedicated.py b/tests/integration/test_pipeline_localbackend_dedicated.py index 11fab51d5..f1154a33f 100644 --- a/tests/integration/test_pipeline_localbackend_dedicated.py +++ b/tests/integration/test_pipeline_localbackend_dedicated.py @@ -167,7 +167,6 @@ async def rollout_fn( max_batch_size=1, max_steps=2, kl_penalty_coef=0.25, - kl_penalty_reference_step=0, loss_fn="cispo", eval_fn=None, ) diff --git a/tests/unit/test_pipeline_trainer_local_backend.py b/tests/unit/test_pipeline_trainer_local_backend.py index 7219e55ad..d17af12dd 100644 --- a/tests/unit/test_pipeline_trainer_local_backend.py +++ b/tests/unit/test_pipeline_trainer_local_backend.py @@ -105,7 +105,6 @@ async def test_pipeline_trainer_forwards_kl_kwargs_for_generic_backend( model=model, backend=backend, kl_penalty_coef=0.25, - kl_penalty_reference_step=7, ) trainer._output_queue = asyncio.Queue() await trainer._output_queue.put(_make_group([0.0, 1.0])) @@ -121,11 +120,77 @@ async def test_pipeline_trainer_forwards_kl_kwargs_for_generic_backend( "save_checkpoint": False, "adam_params": None, "kl_penalty_coef": 0.25, - "kl_penalty_reference_step": 7, + "kl_penalty_reference_step": 0, "kl_penalty_source": "sample", } +@pytest.mark.asyncio +async def test_pipeline_trainer_kl_step_lag_floors_at_zero( + tmp_path: Path, +) -> None: + """kl_penalty_step_lag floors at step 0 when lag > current_step.""" + model = TrainableModel( + name="pipeline-kl-step-lag-floor", + project="pipeline-tests", + base_model="test-model", + base_path=str(tmp_path), + ) + backend = MagicMock() + backend.train = AsyncMock(return_value=SimpleNamespace(step=2, metrics={})) + + trainer = _make_trainer( + model=model, + backend=backend, + kl_penalty_coef=0.25, + kl_penalty_step_lag=5, + ) + trainer._output_queue = asyncio.Queue() + await trainer._output_queue.put(_make_group([0.0, 1.0])) + await trainer._output_queue.put(None) + + # Simulate being at step 1 + trainer.state.next_training_step = 1 + + await trainer._training_stage() + + # At step 1, lag=5 → reference = max(0, 1-5) = 0 + assert backend.train.await_args.kwargs["kl_penalty_reference_step"] == 0 + + +@pytest.mark.asyncio +async def test_pipeline_trainer_kl_step_lag_computes_reference( + tmp_path: Path, +) -> None: + """kl_penalty_step_lag computes reference from current step.""" + model = TrainableModel( + name="pipeline-kl-step-lag", + project="pipeline-tests", + base_model="test-model", + base_path=str(tmp_path), + ) + backend = MagicMock() + backend.train = AsyncMock(return_value=SimpleNamespace(step=4, metrics={})) + + trainer = _make_trainer( + model=model, + backend=backend, + kl_penalty_coef=0.25, + kl_penalty_step_lag=2, + ) + trainer._output_queue = asyncio.Queue() + await trainer._output_queue.put(_make_group([0.0, 1.0])) + await trainer._output_queue.put(None) + + # Simulate being at step 3 + trainer.state.next_training_step = 3 + + await trainer._training_stage() + + # At step 3, lag=2 → reference = max(0, 3-2) = 1 + assert backend.train.await_args.kwargs["kl_penalty_reference_step"] == 1 + + @pytest.mark.asyncio async def test_pipeline_trainer_uses_same_train_kwargs_for_local_backend( tmp_path: Path, From ae9e463b65956c11aa8b587edfcb2742db992088 Mon Sep 17 00:00:00 2001 From: Angky William Date: Fri, 20 Mar 2026 15:41:28 -0700 Subject: [PATCH 13/13] feat: Add TinkerNativeBackend yes-no-maybe KL advantage script MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add tinker variant of KL-penalized advantage training script and align model naming conventions (backend-random-coef) across both. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- dev/run_yes_no_maybe_kl_advantage_tinker.py | 104 ++++++++++++++++++ dev/yes-no-maybe-kl-advantage-tinker.py | 111 ++++++++++++++++++++ dev/yes-no-maybe-kl-advantage.py | 5 +- 3 files changed, 219 insertions(+), 1 deletion(-) create mode 100644 dev/run_yes_no_maybe_kl_advantage_tinker.py create mode 100644 dev/yes-no-maybe-kl-advantage-tinker.py diff --git a/dev/run_yes_no_maybe_kl_advantage_tinker.py b/dev/run_yes_no_maybe_kl_advantage_tinker.py new file mode 100644 index 000000000..468001d30 --- /dev/null +++ b/dev/run_yes_no_maybe_kl_advantage_tinker.py @@ -0,0 +1,104 @@ +"""Launch yes-no-maybe-kl-advantage-tinker training on SkyPilot (Kubernetes). + +Usage: + uv run dev/run_yes_no_maybe_kl_advantage_tinker.py + uv run dev/run_yes_no_maybe_kl_advantage_tinker.py --fast + uv run dev/run_yes_no_maybe_kl_advantage_tinker.py --base-model Qwen/Qwen2.5-7B-Instruct +""" + +import argparse +import os +import textwrap + +from dotenv import load_dotenv +import sky +from sky import ClusterStatus + +load_dotenv() + +parser = argparse.ArgumentParser( + description="Launch yes-no-maybe KL advantage training (Tinker) on SkyPilot." +) +parser.add_argument( + "--fast", action="store_true", help="Skip setup (for re-runs on existing cluster)." +) +parser.add_argument( + "--base-model", type=str, default="meta-llama/Llama-3.1-8B-Instruct" +) +parser.add_argument("--num-steps", type=int, default=20) +parser.add_argument("--kl-penalty-coef", type=float, default=0.1) +parser.add_argument("--accelerator", type=str, default="H200:1") +parser.add_argument("--cluster-name", type=str, default=None) +parser.add_argument( + "--kl-ref-step", + type=int, + default=None, + help="Checkpoint step of training model to use as KL reference", +) +args = parser.parse_args() + +cluster_name = args.cluster_name or f"ynm-tinker-kl-{args.kl_penalty_coef}" +cluster_prefix = os.environ.get("CLUSTER_PREFIX") +if cluster_prefix: + cluster_name = f"{cluster_prefix}-{cluster_name}" + +setup_script = textwrap.dedent("""\ + echo 'Setting up environment...' + apt install -y nvtop + curl -LsSf https://astral.sh/uv/install.sh | sh + source $HOME/.local/bin/env +""") + +kl_ref_env = "" +if args.kl_ref_step is not None: + kl_ref_env = f"KL_REF_STEP={args.kl_ref_step} " + +run_script = textwrap.dedent(f"""\ + source $HOME/.local/bin/env + cd ~/sky_workdir + {kl_ref_env}BASE_MODEL={args.base_model} NUM_STEPS={args.num_steps} KL_PENALTY_COEF={args.kl_penalty_coef} uv run --python 3.11 --extra tinker dev/yes-no-maybe-kl-advantage-tinker.py +""") + +task = sky.Task( + name="yes-no-maybe-kl-advantage-tinker", + setup=setup_script, + run=run_script, + workdir=".", +) +task.set_resources( + sky.Resources(accelerators=args.accelerator, cloud=sky.clouds.Kubernetes()) +) +task.set_file_mounts( + { + "~/sky_workdir/.env": ".env", + } +) + +print(f"Launching on cluster: {cluster_name}") +print(f" base_model: {args.base_model}") +print(f" accelerator: {args.accelerator}") +print(f" num_steps: {args.num_steps}") +print(f" kl_penalty_coef: {args.kl_penalty_coef}") +if args.kl_ref_step is not None: + print(f" kl_ref_step: {args.kl_ref_step}") + +# Cancel any existing jobs on this cluster +cluster_status = sky.stream_and_get(sky.status(cluster_names=[cluster_name])) +if len(cluster_status) > 0 and cluster_status[0]["status"] == ClusterStatus.UP: + print(f"Cluster {cluster_name} is UP. Canceling any active jobs...") + sky.stream_and_get(sky.cancel(cluster_name, all=True)) + +job_id, _ = sky.stream_and_get( + sky.launch( + task, + cluster_name=cluster_name, + retry_until_up=True, + idle_minutes_to_autostop=60, + down=True, + fast=args.fast, + ) +) + +print(f"Job submitted (ID: {job_id}). Streaming logs...") +exit_code = sky.tail_logs(cluster_name=cluster_name, job_id=job_id, follow=True) +print(f"Job {job_id} finished with exit code {exit_code}.") diff --git a/dev/yes-no-maybe-kl-advantage-tinker.py b/dev/yes-no-maybe-kl-advantage-tinker.py new file mode 100644 index 000000000..5983a1f2d --- /dev/null +++ b/dev/yes-no-maybe-kl-advantage-tinker.py @@ -0,0 +1,111 @@ +"""Yes-no-maybe training with KL-penalized advantage adjustment (Tinker backend). + +Demonstrates the kl_penalty_coef feature: tokens where the policy has drifted +more from the reference model get reduced advantages, while tokens that have +drifted less get increased advantages. + +Uses meta-llama/Llama-3.1-8B-Instruct as the base model (trained via Tinker). +""" + +import asyncio +from itertools import permutations +import os +import random +import string + +from dotenv import load_dotenv +import openai + +import art +from art.tinker_native import TinkerNativeBackend + + +async def rollout( + client: openai.AsyncOpenAI, model: art.TrainableModel, prompt: str +) -> art.Trajectory: + messages: art.Messages = [ + { + "role": "user", + "content": prompt, + } + ] + chat_completion = await client.chat.completions.create( + messages=messages, model=model.get_inference_name(), max_tokens=100, timeout=100 + ) + choice = chat_completion.choices[0] + content = choice.message.content + assert isinstance(content, str) + if content == "yes": + reward = 0.5 + elif content == "no": + reward = 0.75 + elif content == "maybe": + reward = 1.0 + else: + reward = 0.0 + return art.Trajectory(messages_and_choices=[*messages, choice], reward=reward) + + +def with_quotes(w: str) -> str: + return f"'{w}'" + + +async def main(): + load_dotenv() + + backend = TinkerNativeBackend() + base_model = os.environ.get("BASE_MODEL", "meta-llama/Llama-3.1-8B-Instruct") + kl_penalty_coef = float(os.environ.get("KL_PENALTY_COEF", "0.1")) + random_suffix = "".join(random.choices(string.ascii_lowercase, k=4)) + model = art.TrainableModel( + name=os.environ.get("MODEL_NAME", f"tinker-{random_suffix}-{kl_penalty_coef}"), + project="yes-no-maybe", + base_model=base_model, + ) + await model.register(backend) + + kl_penalty_reference_step: int | None = ( + int(os.environ["KL_REF_STEP"]) + if os.environ.get("KL_REF_STEP") is not None + else None + ) + + prompts = [ + f"{prefix} with {', '.join([with_quotes(w) if use_quotes else w for w in words]) if len(words) == 3 else f'{words[0]}' + (f' or {words[1]}' if len(words) > 1 else '')}" + for prefix in ["respond", "just respond"] + for use_quotes in [True, False] + for words in ( + list(p) for n in [3, 2] for p in permutations(["yes", "no", "maybe"], n) + ) + ] + + openai_client = model.openai_client() + max_steps = int(os.environ.get("NUM_STEPS", "20")) + start_step = await model.get_step() + for step in range(start_step, start_step + max_steps): + train_groups = await art.gather_trajectory_groups( + ( + art.TrajectoryGroup( + rollout(openai_client, model, prompt) for _ in range(32) + ) + for prompt in prompts + ) + ) + result = await backend.train( + model, + train_groups, + learning_rate=1e-4, + kl_penalty_coef=kl_penalty_coef, + kl_penalty_reference_step=kl_penalty_reference_step, + ) + await model.log( + train_groups, + metrics=result.metrics, + step=result.step, + split="train", + ) + print(f"step {result.step}: {result.metrics}") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/dev/yes-no-maybe-kl-advantage.py b/dev/yes-no-maybe-kl-advantage.py index 41ce0b119..ccd21b243 100644 --- a/dev/yes-no-maybe-kl-advantage.py +++ b/dev/yes-no-maybe-kl-advantage.py @@ -10,6 +10,8 @@ import asyncio from itertools import permutations import os +import random +import string from dotenv import load_dotenv import openai @@ -54,8 +56,9 @@ async def main(): backend = LocalBackend() base_model = os.environ.get("BASE_MODEL", "meta-llama/Meta-Llama-3.1-8B-Instruct") kl_penalty_coef = float(os.environ.get("KL_PENALTY_COEF", "0.1")) + random_suffix = "".join(random.choices(string.ascii_lowercase, k=4)) model = art.TrainableModel( - name=os.environ.get("MODEL_NAME", f"kl-{kl_penalty_coef}"), + name=os.environ.get("MODEL_NAME", f"local-{random_suffix}-{kl_penalty_coef}"), project="yes-no-maybe", base_model=base_model, )