diff --git a/dev/run_yes_no_maybe_kl_advantage_tinker.py b/dev/run_yes_no_maybe_kl_advantage_tinker.py new file mode 100644 index 00000000..468001d3 --- /dev/null +++ b/dev/run_yes_no_maybe_kl_advantage_tinker.py @@ -0,0 +1,104 @@ +"""Launch yes-no-maybe-kl-advantage-tinker training on SkyPilot (Kubernetes). + +Usage: + uv run dev/run_yes_no_maybe_kl_advantage_tinker.py + uv run dev/run_yes_no_maybe_kl_advantage_tinker.py --fast + uv run dev/run_yes_no_maybe_kl_advantage_tinker.py --base-model Qwen/Qwen2.5-7B-Instruct +""" + +import argparse +import os +import textwrap + +from dotenv import load_dotenv +import sky +from sky import ClusterStatus + +load_dotenv() + +parser = argparse.ArgumentParser( + description="Launch yes-no-maybe KL advantage training (Tinker) on SkyPilot." +) +parser.add_argument( + "--fast", action="store_true", help="Skip setup (for re-runs on existing cluster)." +) +parser.add_argument( + "--base-model", type=str, default="meta-llama/Llama-3.1-8B-Instruct" +) +parser.add_argument("--num-steps", type=int, default=20) +parser.add_argument("--kl-penalty-coef", type=float, default=0.1) +parser.add_argument("--accelerator", type=str, default="H200:1") +parser.add_argument("--cluster-name", type=str, default=None) +parser.add_argument( + "--kl-ref-step", + type=int, + default=None, + help="Checkpoint step of training model to use as KL reference", +) +args = parser.parse_args() + +cluster_name = args.cluster_name or f"ynm-tinker-kl-{args.kl_penalty_coef}" +cluster_prefix = os.environ.get("CLUSTER_PREFIX") +if cluster_prefix: + cluster_name = f"{cluster_prefix}-{cluster_name}" + +setup_script = textwrap.dedent("""\ + echo 'Setting up environment...' + apt install -y nvtop + curl -LsSf https://astral.sh/uv/install.sh | sh + source $HOME/.local/bin/env +""") + +kl_ref_env = "" +if args.kl_ref_step is not None: + kl_ref_env = f"KL_REF_STEP={args.kl_ref_step} " + +run_script = textwrap.dedent(f"""\ + source $HOME/.local/bin/env + cd ~/sky_workdir + {kl_ref_env}BASE_MODEL={args.base_model} NUM_STEPS={args.num_steps} KL_PENALTY_COEF={args.kl_penalty_coef} uv run --python 3.11 --extra tinker dev/yes-no-maybe-kl-advantage-tinker.py +""") + +task = sky.Task( + name="yes-no-maybe-kl-advantage-tinker", + setup=setup_script, + run=run_script, + workdir=".", +) +task.set_resources( + sky.Resources(accelerators=args.accelerator, cloud=sky.clouds.Kubernetes()) +) +task.set_file_mounts( + { + "~/sky_workdir/.env": ".env", + } +) + +print(f"Launching on cluster: {cluster_name}") +print(f" base_model: {args.base_model}") +print(f" accelerator: {args.accelerator}") +print(f" num_steps: {args.num_steps}") +print(f" kl_penalty_coef: {args.kl_penalty_coef}") +if args.kl_ref_step is not None: + print(f" kl_ref_step: {args.kl_ref_step}") + +# Cancel any existing jobs on this cluster +cluster_status = sky.stream_and_get(sky.status(cluster_names=[cluster_name])) +if len(cluster_status) > 0 and cluster_status[0]["status"] == ClusterStatus.UP: + print(f"Cluster {cluster_name} is UP. Canceling any active jobs...") + sky.stream_and_get(sky.cancel(cluster_name, all=True)) + +job_id, _ = sky.stream_and_get( + sky.launch( + task, + cluster_name=cluster_name, + retry_until_up=True, + idle_minutes_to_autostop=60, + down=True, + fast=args.fast, + ) +) + +print(f"Job submitted (ID: {job_id}). 
Streaming logs...") +exit_code = sky.tail_logs(cluster_name=cluster_name, job_id=job_id, follow=True) +print(f"Job {job_id} finished with exit code {exit_code}.") diff --git a/dev/yes-no-maybe-kl-advantage-tinker.py b/dev/yes-no-maybe-kl-advantage-tinker.py new file mode 100644 index 00000000..5983a1f2 --- /dev/null +++ b/dev/yes-no-maybe-kl-advantage-tinker.py @@ -0,0 +1,111 @@ +"""Yes-no-maybe training with KL-penalized advantage adjustment (Tinker backend). + +Demonstrates the kl_penalty_coef feature: tokens where the policy has drifted +more from the reference model get reduced advantages, while tokens that have +drifted less get increased advantages. + +Uses meta-llama/Llama-3.1-8B-Instruct as the base model (trained via Tinker). +""" + +import asyncio +from itertools import permutations +import os +import random +import string + +from dotenv import load_dotenv +import openai + +import art +from art.tinker_native import TinkerNativeBackend + + +async def rollout( + client: openai.AsyncOpenAI, model: art.TrainableModel, prompt: str +) -> art.Trajectory: + messages: art.Messages = [ + { + "role": "user", + "content": prompt, + } + ] + chat_completion = await client.chat.completions.create( + messages=messages, model=model.get_inference_name(), max_tokens=100, timeout=100 + ) + choice = chat_completion.choices[0] + content = choice.message.content + assert isinstance(content, str) + if content == "yes": + reward = 0.5 + elif content == "no": + reward = 0.75 + elif content == "maybe": + reward = 1.0 + else: + reward = 0.0 + return art.Trajectory(messages_and_choices=[*messages, choice], reward=reward) + + +def with_quotes(w: str) -> str: + return f"'{w}'" + + +async def main(): + load_dotenv() + + backend = TinkerNativeBackend() + base_model = os.environ.get("BASE_MODEL", "meta-llama/Llama-3.1-8B-Instruct") + kl_penalty_coef = float(os.environ.get("KL_PENALTY_COEF", "0.1")) + random_suffix = "".join(random.choices(string.ascii_lowercase, k=4)) + model = art.TrainableModel( + name=os.environ.get("MODEL_NAME", f"tinker-{random_suffix}-{kl_penalty_coef}"), + project="yes-no-maybe", + base_model=base_model, + ) + await model.register(backend) + + kl_penalty_reference_step: int | None = ( + int(os.environ["KL_REF_STEP"]) + if os.environ.get("KL_REF_STEP") is not None + else None + ) + + prompts = [ + f"{prefix} with {', '.join([with_quotes(w) if use_quotes else w for w in words]) if len(words) == 3 else f'{words[0]}' + (f' or {words[1]}' if len(words) > 1 else '')}" + for prefix in ["respond", "just respond"] + for use_quotes in [True, False] + for words in ( + list(p) for n in [3, 2] for p in permutations(["yes", "no", "maybe"], n) + ) + ] + + openai_client = model.openai_client() + max_steps = int(os.environ.get("NUM_STEPS", "20")) + start_step = await model.get_step() + for step in range(start_step, start_step + max_steps): + train_groups = await art.gather_trajectory_groups( + ( + art.TrajectoryGroup( + rollout(openai_client, model, prompt) for _ in range(32) + ) + for prompt in prompts + ) + ) + result = await backend.train( + model, + train_groups, + learning_rate=1e-4, + kl_penalty_coef=kl_penalty_coef, + kl_penalty_reference_step=kl_penalty_reference_step, + ) + await model.log( + train_groups, + metrics=result.metrics, + step=result.step, + split="train", + ) + print(f"step {result.step}: {result.metrics}") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/dev/yes-no-maybe-kl-advantage.py b/dev/yes-no-maybe-kl-advantage.py index 41ce0b11..ccd21b24 100644 --- 
a/dev/yes-no-maybe-kl-advantage.py
+++ b/dev/yes-no-maybe-kl-advantage.py
@@ -10,6 +10,8 @@
 import asyncio
 from itertools import permutations
 import os
+import random
+import string
 
 from dotenv import load_dotenv
 import openai
@@ -54,8 +56,9 @@ async def main():
     backend = LocalBackend()
     base_model = os.environ.get("BASE_MODEL", "meta-llama/Meta-Llama-3.1-8B-Instruct")
     kl_penalty_coef = float(os.environ.get("KL_PENALTY_COEF", "0.1"))
+    random_suffix = "".join(random.choices(string.ascii_lowercase, k=4))
     model = art.TrainableModel(
-        name=os.environ.get("MODEL_NAME", f"kl-{kl_penalty_coef}"),
+        name=os.environ.get("MODEL_NAME", f"local-{random_suffix}-{kl_penalty_coef}"),
         project="yes-no-maybe",
         base_model=base_model,
     )
diff --git a/docs/features/checkpoint-forking.mdx b/docs/features/checkpoint-forking.mdx
index c3d3603d..4a214c85 100644
--- a/docs/features/checkpoint-forking.mdx
+++ b/docs/features/checkpoint-forking.mdx
@@ -31,7 +31,7 @@ import art
 from art.local import LocalBackend
 
 async def train():
-    with LocalBackend() as backend:
+    async with LocalBackend() as backend:
         # Create a new model that will fork from an existing checkpoint
         model = art.TrainableModel(
             name="my-model-v2",
@@ -115,7 +115,7 @@ low_lr_model = art.TrainableModel(
 )
 
 async def experiment():
-    with LocalBackend() as backend:
+    async with LocalBackend() as backend:
         # Fork the model from the base model
         await backend._experimental_fork_checkpoint(
             low_lr_model,
diff --git a/docs/fundamentals/art-backend.mdx b/docs/fundamentals/art-backend.mdx
index f65473a9..9e6019c0 100644
--- a/docs/fundamentals/art-backend.mdx
+++ b/docs/fundamentals/art-backend.mdx
@@ -73,6 +73,30 @@ backend = LocalBackend(
 )
 ```
 
+If you're using `PipelineTrainer`, `LocalBackend` is currently supported only in dedicated mode, where training and inference run on separate GPUs.
+
+```python
+from art import TrainableModel
+from art.dev import InternalModelConfig
+from art.local import LocalBackend
+
+backend = LocalBackend(path="./.art")
+model = TrainableModel(
+    name="pipeline-localbackend",
+    project="my-project",
+    base_model="Qwen/Qwen3-0.6B",
+    _internal_config=InternalModelConfig(
+        trainer_gpu_ids=[0],
+        inference_gpu_ids=[1],
+    ),
+)
+```
+
+A shared `LocalBackend` pauses inference during training, so ART rejects that configuration for `PipelineTrainer`.
+
+In dedicated mode, a new checkpoint becomes the default inference target only after its LoRA adapter has been reloaded into vLLM. That publication flow is backend-specific, so `save_checkpoint` does not have identical semantics across every ART backend.
+Requests that are already in flight keep using the adapter they started with; the reload only affects which step subsequent requests are routed to.
+
 ## Using a backend
 
 Once initialized, a backend can be used in the same way regardless of whether it runs locally or remotely.
diff --git a/docs/fundamentals/training-loop.mdx b/docs/fundamentals/training-loop.mdx
index b7bc8fe9..4c8f4f75 100644
--- a/docs/fundamentals/training-loop.mdx
+++ b/docs/fundamentals/training-loop.mdx
@@ -22,6 +22,8 @@ ART's functionality is divided into a [**client**](/fundamentals/art-client) and
 
 This training loop runs until a specified number of inference and training iterations have completed.
 
+This describes the default shared-resource loop. `PipelineTrainer` can also run with `LocalBackend` in dedicated mode, where training and inference stay on separate GPUs and the latest served step advances only after vLLM reloads the new LoRA adapter.
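+
+A minimal sketch of that dedicated-mode setup, mirroring this repo's smoke test (the model name, GPU ids, batch sizes, and the `rollout_fn`/`scenario_iter` helpers are illustrative):
+
+```python
+import art
+from art.dev import InternalModelConfig
+from art.local import LocalBackend
+from art.pipeline_trainer import PipelineTrainer
+
+backend = LocalBackend()
+model = art.TrainableModel(
+    name="pipeline-dedicated",
+    project="my-project",
+    base_model="Qwen/Qwen3-0.6B",
+    # Dedicated mode: trainer and vLLM inference on separate GPUs.
+    _internal_config=InternalModelConfig(
+        trainer_gpu_ids=[0],
+        inference_gpu_ids=[1],
+    ),
+)
+await model.register(backend)
+
+trainer = PipelineTrainer(
+    model=model,
+    backend=backend,
+    rollout_fn=rollout_fn,        # your rollout function
+    scenarios=scenario_iter(),    # your async scenario iterator
+    config=None,
+    num_rollout_workers=2,
+    min_batch_size=1,
+    max_batch_size=4,
+    max_steps=2,
+    loss_fn="cispo",
+    kl_penalty_coef=0.25,  # optional KL-penalized advantage shaping
+    eval_fn=None,
+)
+await trainer.train()
+```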
+
 Training and inference use both the ART **client** and **backend**. Learn more by following the links below!
diff --git a/src/art/dev/train.py b/src/art/dev/train.py index b0e232c5..5da3e1ab 100644 --- a/src/art/dev/train.py +++ b/src/art/dev/train.py @@ -18,6 +18,7 @@ class TrainConfig(TypedDict, total=False): ] kimi_k2_tau: float | None kl_penalty_coef: float + kl_penalty_source: Literal["current_learner", "sample"] kl_ref_adapter_path: str | None logprob_calculation_chunk_size: int mask_prob_ratio: bool diff --git a/src/art/local/backend.py b/src/art/local/backend.py index 5baf200f..723383ad 100644 --- a/src/art/local/backend.py +++ b/src/art/local/backend.py @@ -159,6 +159,9 @@ def _allocated_gpu_count(self, model: Model) -> int: def __enter__(self) -> Self: return self + async def __aenter__(self) -> Self: + return self + def __exit__( self, exc_type: type[BaseException] | None, @@ -167,14 +170,30 @@ def __exit__( ) -> None: self._close() + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc: BaseException | None, + tb: TracebackType | None, + ) -> None: + await self.close() + async def close(self) -> None: """ If running vLLM in a separate process, this will kill that process and close the communication threads. """ - self._close() + for service in self._services.values(): + aclose = getattr(service, "aclose", None) + if aclose is None: + close = getattr(service, "close", None) + if close is not None: + close() + else: + await aclose() + close_proxy(service) def _close(self) -> None: - for _, service in self._services.items(): + for service in self._services.values(): close = getattr(service, "close", None) if close is not None: close() @@ -219,25 +238,27 @@ def _model_inference_name(self, model: Model, step: int | None = None) -> str: If None, returns name for latest checkpoint (step 0 initially). """ - # For LocalBackend, vLLM always serves LoRA adapters with @step suffix - # Default to step 0 when not specified (the initial checkpoint created at registration) - if step is not None: - actual_step = step - elif model.name in self._services and self._in_process: - # In dedicated mode the service tracks which adapter vLLM has - # actually loaded. Reading the filesystem would race: the - # checkpoint directory appears before the HTTP reload completes. - svc = self._services[model.name] - loaded_step = getattr(svc, "_latest_step", None) - actual_step = ( - loaded_step if loaded_step is not None else self.__get_step(model) - ) - else: - actual_step = self.__get_step(model) - name = f"{model.name}@{actual_step}" + requested_step = step + + if step is None and isinstance(model, TrainableModel): + from ..dev.validate import is_dedicated_mode + + service = self._services.get(model.name) + if service is not None and is_dedicated_mode( + model._internal_config or dev.InternalModelConfig() + ): + loaded_step = getattr(service, "_latest_step", None) + if isinstance(loaded_step, int): + step = loaded_step + + if step is None: + # The checkpoint directory is written before dedicated-mode + # vLLM finishes reloading the new adapter. 
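+            # Falling back to the on-disk step covers shared mode and the
+            # window before the service has reported a loaded adapter
+            # (a step-0 checkpoint is created at registration).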
+ step = self.__get_step(model) + name = f"{model.name}@{step}" logger.debug( - f"[BACKEND] _model_inference_name: step_arg={step} " - f"actual_step={actual_step} -> {name}" + f"[BACKEND] _model_inference_name: step_arg={requested_step} " + f"actual_step={step} -> {name}" ) return name @@ -502,12 +523,15 @@ async def train( # type: ignore[override] *, # Core training parameters learning_rate: float = 5e-6, + loss_fn: Literal["cispo", "ppo"] = "cispo", + loss_fn_config: dict | None = None, + normalize_advantages: bool = True, + adam_params: object | None = None, # KL-penalized advantage adjustment kl_penalty_coef: float = 0.0, kl_penalty_reference_step: int | None = None, kl_ref_adapter_path: str | None = None, - # RL algorithm settings - ppo: bool = False, + kl_penalty_source: Literal["current_learner", "sample"] = "current_learner", epsilon: float | None = None, epsilon_high: float | None = None, # Advantage computation @@ -544,6 +568,14 @@ async def train( # type: ignore[override] model: The trainable model to train. trajectory_groups: Batches of trajectories to train on. learning_rate: Learning rate for training. Defaults to 5e-6. + loss_fn: RL loss function. LocalBackend currently supports + "cispo" and "ppo". + loss_fn_config: Additional loss-function config. Not supported by + LocalBackend. + normalize_advantages: Whether to normalize advantages. LocalBackend + currently requires True. + adam_params: Custom optimizer params. Not supported by + LocalBackend. kl_penalty_coef: Coefficient for KL-penalized advantage adjustment. Tokens diverging more from the reference get reduced advantages. Defaults to 0.0 (disabled). @@ -553,8 +585,12 @@ async def train( # type: ignore[override] kl_ref_adapter_path: Direct filesystem path to a LoRA adapter checkpoint to use as the KL reference. Alternative to kl_penalty_reference_step. - ppo: Whether to use PPO clipping. Defaults to False. - epsilon: Clip epsilon for importance sampling. Defaults based on ppo. + kl_penalty_source: Which policy's logprobs to compare against the + reference when building the centered KL penalty. Use + "current_learner" to match the original ART implementation, or + "sample" to shape from the rollout policy logprobs, which is + usually better for async/off-policy workloads. + epsilon: Clip epsilon for importance sampling. Defaults based on loss_fn. epsilon_high: Asymmetric upper clip bound. Defaults to epsilon. advantage_balance: Balance between negative and positive advantages in range [-1.0, 1.0]. Defaults to 0.0 (balanced). 
@@ -597,19 +633,31 @@ async def train( # type: ignore[override] # await model.log(metrics=result.metrics, step=result.step) """ groups_list = list(trajectory_groups) + if loss_fn not in {"cispo", "ppo"}: + raise ValueError("LocalBackend only supports loss_fn='cispo' or 'ppo'.") + if loss_fn_config is not None: + raise ValueError("LocalBackend requires loss_fn_config=None.") + if not normalize_advantages: + raise ValueError("LocalBackend requires normalize_advantages=True.") + if adam_params is not None: + raise ValueError("LocalBackend requires adam_params=None.") + assert kl_penalty_source in {"current_learner", "sample"} # Build config objects from explicit kwargs config = TrainConfig( - learning_rate=learning_rate, kl_penalty_coef=kl_penalty_coef + learning_rate=learning_rate, + kl_penalty_coef=kl_penalty_coef, + kl_penalty_source=kl_penalty_source, ) dev_config: dev.TrainConfig = { "advantage_balance": advantage_balance, "allow_training_without_logprobs": allow_training_without_logprobs, "importance_sampling_level": importance_sampling_level, "kl_penalty_coef": kl_penalty_coef, + "kl_penalty_source": kl_penalty_source, "mask_prob_ratio": mask_prob_ratio, "plot_tensors": plot_tensors, - "ppo": ppo, + "ppo": loss_fn == "ppo", "precalculate_logprobs": precalculate_logprobs, "scale_learning_rate_by_reward_std_dev": scale_learning_rate_by_reward_std_dev, "scale_rewards": scale_rewards, diff --git a/src/art/loss.py b/src/art/loss.py index 5a73d7b7..59cfa46a 100644 --- a/src/art/loss.py +++ b/src/art/loss.py @@ -95,7 +95,14 @@ def loss_fn( kl_policy_ref: torch.Tensor | None = None kl_penalty_coef = experimental_config.get("kl_penalty_coef", 0.0) if kl_penalty_coef > 0 and ref_logprobs is not None: - kl_per_token = (new_logprobs - ref_logprobs).detach() * assistant_mask + match experimental_config.get("kl_penalty_source", "current_learner"): + case "sample": + kl_source_logprobs = old_logprobs.detach() + case "current_learner": + kl_source_logprobs = new_logprobs.detach() + case other: + raise AssertionError(other) + kl_per_token = (kl_source_logprobs - ref_logprobs).detach() * assistant_mask avg_kl = kl_per_token.sum() / (assistant_mask.sum() + 1e-6) kl_penalty = kl_penalty_coef * (avg_kl - kl_per_token) * assistant_mask advantages = advantages + kl_penalty diff --git a/src/art/pipeline_trainer/trainer.py b/src/art/pipeline_trainer/trainer.py index 1458d153..2522cc58 100644 --- a/src/art/pipeline_trainer/trainer.py +++ b/src/art/pipeline_trainer/trainer.py @@ -78,6 +78,8 @@ def __init__( loss_fn_config: dict | None = None, normalize_advantages: bool = True, adam_params: object | None = None, + kl_penalty_coef: float = 0.0, + kl_penalty_step_lag: int | None = None, max_steps: int | None = None, # Discard handling discard_queue_multiplier: int = 100, @@ -112,6 +114,8 @@ def __init__( raise ValueError("log_interval_seconds must be > 0") if discard_queue_multiplier <= 0: raise ValueError("discard_queue_multiplier must be > 0") + if kl_penalty_step_lag is not None and kl_penalty_step_lag < 1: + raise ValueError("kl_penalty_step_lag must be >= 1") self.model = model self.backend = backend self.rollout_fn = rollout_fn @@ -129,6 +133,8 @@ def __init__( self.loss_fn_config = loss_fn_config self.normalize_advantages = normalize_advantages self.adam_params = adam_params + self.kl_penalty_coef = kl_penalty_coef + self.kl_penalty_step_lag = kl_penalty_step_lag self.max_steps = max_steps self._status_log_interval_seconds = log_interval_seconds self.eval_every_n_steps = eval_every_n_steps @@ -154,6 +160,7 
@@ def __init__( total_scenarios=total_scenarios, num_workers=num_rollout_workers, ) + self._validate_backend_support() async def train(self, *, handle_signals: bool = True) -> None: """Run the training pipeline over the configured scenario iterator.""" @@ -277,6 +284,42 @@ async def _notify_policy() -> None: except asyncio.QueueFull: loop.create_task(self._output_queue.put(None)) + def _validate_backend_support(self) -> None: + from art.dev.validate import is_dedicated_mode + from art.local.backend import LocalBackend + + if not isinstance(self.backend, LocalBackend): + return + + model_config = self.model._internal_config or art.dev.InternalModelConfig() + if not is_dedicated_mode(model_config): + raise ValueError( + "PipelineTrainer only supports LocalBackend in dedicated mode. " + "Shared LocalBackend pauses inference during training and is not " + "a supported async PipelineTrainer path. Set both " + "trainer_gpu_ids and inference_gpu_ids on the TrainableModel " + "_internal_config to use LocalBackend with PipelineTrainer." + ) + if self.loss_fn not in {"cispo", "ppo"}: + raise ValueError( + "PipelineTrainer + LocalBackend(dedicated) only supports " + "loss_fn='cispo' or loss_fn='ppo'." + ) + if self.loss_fn_config is not None: + raise ValueError( + "PipelineTrainer + LocalBackend(dedicated) requires " + "loss_fn_config=None." + ) + if not self.normalize_advantages: + raise ValueError( + "PipelineTrainer + LocalBackend(dedicated) requires " + "normalize_advantages=True." + ) + if self.adam_params is not None: + raise ValueError( + "PipelineTrainer + LocalBackend(dedicated) requires adam_params=None." + ) + async def _skip_scenarios( self, scenarios: AsyncIterator[ScenarioT], count: int ) -> int: @@ -415,6 +458,16 @@ async def _training_stage(self) -> None: if os.getenv("ART_TRAIN_STEP_LOG"): print(f"[train] step {expected_step} starting (batch={len(batch)})") try: + kl_train_kwargs: dict[str, object] = {} + if self.kl_penalty_coef > 0.0: + kl_train_kwargs["kl_penalty_coef"] = self.kl_penalty_coef + kl_train_kwargs["kl_penalty_source"] = "sample" + if self.kl_penalty_step_lag is None: + kl_train_kwargs["kl_penalty_reference_step"] = 0 + else: + kl_train_kwargs["kl_penalty_reference_step"] = max( + 0, current_step - self.kl_penalty_step_lag + ) result = await self.backend.train( self.model, batch, @@ -424,6 +477,7 @@ async def _training_stage(self) -> None: normalize_advantages=self.normalize_advantages, save_checkpoint=should_checkpoint, adam_params=self.adam_params, + **kl_train_kwargs, ) except Exception: self._status.note_training_end() @@ -515,7 +569,7 @@ async def _collect_batch( continue batch.append(item) - while not saw_sentinel: + while not saw_sentinel and len(batch) < self.max_batch_size: try: item = self._output_queue.get_nowait() except asyncio.QueueEmpty: diff --git a/src/art/test/test_kl_advantage.py b/src/art/test/test_kl_advantage.py index d944efc6..82c0f2a2 100644 --- a/src/art/test/test_kl_advantage.py +++ b/src/art/test/test_kl_advantage.py @@ -2,7 +2,7 @@ import torch -from art.loss import Loss, loss_fn +from art.loss import loss_fn, shift_tensor def _make_inputs( @@ -114,3 +114,50 @@ def test_kl_advantage_does_not_affect_when_no_ref(): loss = loss_fn(inputs, new_logprobs, None, None, {"kl_penalty_coef": 0.5}) assert loss.kl_policy_ref is None + + +def test_kl_advantage_can_use_sample_logprobs() -> None: + """Sample-source KL should use stored rollout logprobs rather than learner logprobs.""" + inputs = _make_inputs(seq_len=8) + inputs["logprobs"] = torch.tensor( + 
[[0.0, -0.2, -0.4, -0.6, -0.8, -1.0, -1.2, -1.4]], dtype=torch.float32 + ) + new_logprobs = torch.tensor( + [[0.0, -1.2, -1.1, -1.0, -0.9, -0.8, -0.7, -0.6]], dtype=torch.float32 + ) + ref_logprobs = torch.full((1, 8), -0.5) + assistant_mask = shift_tensor(inputs["assistant_mask"], False).to( + new_logprobs.dtype + ) + sampled_logprobs = torch.where( + torch.isnan(shift_tensor(inputs["logprobs"], float("nan"))), + new_logprobs.detach(), + shift_tensor(inputs["logprobs"], float("nan")), + ) + expected_sample_kl = ((sampled_logprobs - ref_logprobs) * assistant_mask).sum() / ( + assistant_mask.sum() + 1e-6 + ) + expected_current_kl = ((new_logprobs - ref_logprobs) * assistant_mask).sum() / ( + assistant_mask.sum() + 1e-6 + ) + + sample_loss = loss_fn( + inputs, + new_logprobs, + ref_logprobs, + None, + {"kl_penalty_coef": 0.5, "kl_penalty_source": "sample"}, + ) + learner_loss = loss_fn( + inputs, + new_logprobs, + ref_logprobs, + None, + {"kl_penalty_coef": 0.5, "kl_penalty_source": "current_learner"}, + ) + + assert sample_loss.kl_policy_ref is not None + assert learner_loss.kl_policy_ref is not None + assert torch.isclose(sample_loss.kl_policy_ref, expected_sample_kl) + assert torch.isclose(learner_loss.kl_policy_ref, expected_current_kl) + assert not torch.isclose(sample_loss.kl_policy_ref, learner_loss.kl_policy_ref) diff --git a/src/art/test/test_step_skipping.py b/src/art/test/test_step_skipping.py index f4c85a1b..0a048b5a 100755 --- a/src/art/test/test_step_skipping.py +++ b/src/art/test/test_step_skipping.py @@ -44,7 +44,7 @@ async def test_step_skipping(): # Set up backend with custom art path art_path = os.path.join(tmpdir, ".art") - with LocalBackend(path=art_path) as backend: + async with LocalBackend(path=art_path) as backend: # Create a test model model = TrainableModel( name=f"test-step-skip-{uuid.uuid4()}", diff --git a/src/art/tinker_native/backend.py b/src/art/tinker_native/backend.py index c1687bf7..65f59ca6 100644 --- a/src/art/tinker_native/backend.py +++ b/src/art/tinker_native/backend.py @@ -24,6 +24,7 @@ from openai.types.chat.completion_create_params import CompletionCreateParams from openai.types.completion_usage import CompletionUsage import tinker +import torch import uvicorn from art.tinker.cookbook_v import renderers, tokenizer_utils @@ -82,6 +83,76 @@ def _canonicalize_upstream_metric_key(metric: str) -> str: return _UPSTREAM_TRAIN_METRIC_KEYS.get(metric, metric) +async def _apply_kl_penalty( + datums: list[tinker.Datum], + reference_sampling_client: tinker.SamplingClient, + kl_penalty_coef: float, +) -> dict[str, float]: + assert datums + assert kl_penalty_coef > 0.0 + + full_sequences: list[tinker.ModelInput] = [] + sampled_logprobs_by_datum: list[torch.Tensor] = [] + masks_by_datum: list[torch.Tensor] = [] + advantages_by_datum: list[torch.Tensor] = [] + for datum in datums: + target_tokens = datum.loss_fn_inputs["target_tokens"].to_torch() + assert target_tokens.numel() > 0 + full_sequences.append( + datum.model_input.append_int(int(target_tokens[-1].item())) + ) + sampled_logprobs_by_datum.append(datum.loss_fn_inputs["logprobs"].to_torch()) + masks_by_datum.append(datum.loss_fn_inputs["mask"].to_torch().float()) + advantages_by_datum.append(datum.loss_fn_inputs["advantages"].to_torch()) + + reference_logprobs_by_datum = await asyncio.gather( + *[ + reference_sampling_client.compute_logprobs_async(full_sequence) + for full_sequence in full_sequences + ] + ) + + logprob_diffs_by_datum: list[torch.Tensor] = [] + for reference_logprobs, sampled_logprobs, mask 
in zip( + reference_logprobs_by_datum, + sampled_logprobs_by_datum, + masks_by_datum, + strict=True, + ): + reference_values = reference_logprobs[1:] + assert len(reference_values) == sampled_logprobs.numel() + assert all(value is not None for value in reference_values) + reference_logprobs_tensor = torch.tensor( + reference_values, + dtype=sampled_logprobs.dtype, + ) + logprob_diffs_by_datum.append( + (sampled_logprobs - reference_logprobs_tensor) * mask + ) + + total_tokens = torch.stack([mask.sum() for mask in masks_by_datum]).sum() + assert total_tokens.item() > 0 + avg_logprob_diff = ( + torch.stack( + [logprob_diff.sum() for logprob_diff in logprob_diffs_by_datum] + ).sum() + / total_tokens + ) + + for datum, advantages, mask, logprob_diff in zip( + datums, + advantages_by_datum, + masks_by_datum, + logprob_diffs_by_datum, + strict=True, + ): + datum.loss_fn_inputs["advantages"] = tinker.TensorData.from_torch( + advantages + kl_penalty_coef * (avg_logprob_diff - logprob_diff) * mask + ) + + return {"loss/kl_policy_ref": float(avg_logprob_diff)} + + @dataclass class ModelState: service_client: tinker.ServiceClient @@ -239,7 +310,14 @@ async def train( # type: ignore[override] save_checkpoint: bool = False, loss_fn_config: dict | None = None, adam_params: tinker.AdamParams | None = None, + kl_penalty_coef: float = 0.0, + kl_penalty_reference_step: int | None = None, + kl_penalty_source: Literal["sample"] = "sample", ) -> TrainResult: + assert kl_penalty_source == "sample", ( + "TinkerNativeBackend only supports kl_penalty_source='sample'." + ) + state = self._model_state[model.name] groups_list = list(trajectory_groups) summary = summarize_trajectory_groups(groups_list) @@ -272,6 +350,23 @@ async def train( # type: ignore[override] train_tokens, pricing ) trainer_started = time.monotonic() + sampled_kl_policy_ref: float | None = None + + if kl_penalty_coef > 0: + kl_metrics = await self._tinker_sample_call( + "apply_kl_penalty", + _apply_kl_penalty( + datums, + await self._get_kl_reference_sampling_client( + state, + model.base_model, + kl_penalty_reference_step, + ), + kl_penalty_coef, + ), + ) + sampled_kl_policy_ref = kl_metrics["loss/kl_policy_ref"] + metrics.update(kl_metrics) if adam_params is None: adam_params = tinker.AdamParams( @@ -310,6 +405,11 @@ def remove_mask(datum: tinker.Datum) -> tinker.Datum: if value is None: continue canonical_key = _canonicalize_upstream_metric_key(key) + if ( + sampled_kl_policy_ref is not None + and canonical_key == "loss/kl_policy_ref" + ): + continue if canonical_key: metrics[canonical_key] = float(value) if optim_output.metrics: @@ -317,6 +417,11 @@ def remove_mask(datum: tinker.Datum) -> tinker.Datum: if value is None: continue canonical_key = _canonicalize_upstream_metric_key(key) + if ( + sampled_kl_policy_ref is not None + and canonical_key == "loss/kl_policy_ref" + ): + continue if canonical_key: metrics[canonical_key] = float(value) @@ -697,6 +802,19 @@ async def _get_sampler_client( state.sampler_clients[actual_step] = sampler_client return sampler_client + async def _get_kl_reference_sampling_client( + self, + state: ModelState, + base_model: str, + step: int | None, + ) -> tinker.SamplingClient: + if step is not None: + return await self._get_sampler_client(state, step) + return await self._tinker_sample_call( + "create_sampling_client_async", + state.service_client.create_sampling_client_async(base_model=base_model), + ) + def _normalize_messages(self, messages: Iterable[Any]) -> list[dict[str, Any]]: normalized: list[dict[str, 
Any]] = [] for message in messages: diff --git a/src/art/types.py b/src/art/types.py index 088041ad..317fc156 100644 --- a/src/art/types.py +++ b/src/art/types.py @@ -17,6 +17,7 @@ class TrainConfig(pydantic.BaseModel): learning_rate: float = 5e-6 kl_penalty_coef: float = 0.0 + kl_penalty_source: Literal["current_learner", "sample"] = "current_learner" class TrainSFTConfig(pydantic.BaseModel): diff --git a/src/art/unsloth/service.py b/src/art/unsloth/service.py index f24be80f..5b6a563c 100644 --- a/src/art/unsloth/service.py +++ b/src/art/unsloth/service.py @@ -35,7 +35,7 @@ from ..utils.get_model_step import get_step_from_dir from ..utils.output_dirs import get_step_checkpoint_dir from ..vllm import get_llm, get_worker, openai_server_task, run_on_workers -from .train import gc_and_empty_cuda_cache, train +from .train import StopTrainingLoop, gc_and_empty_cuda_cache, train logger = logging.getLogger(__name__) @@ -55,6 +55,15 @@ class SupportsLoadLora(Protocol): def load_lora(self, lora_path: str, load_tensors: bool = True) -> LoRARequest: ... +class _StopTrainInputs: + """Dedicated sentinel for stopping the background trainer loop.""" + + +_STOP_TRAIN_INPUT = _StopTrainInputs() +_TRAIN_TASK_SHUTDOWN_TIMEOUT_S = 5.0 +_TrainLoopInput = TrainInputs | _StopTrainInputs + + def precalculate_new_logprobs( trainer: "GRPOTrainer", peft_model: "PeftModelForCausalLM", @@ -91,7 +100,7 @@ async def process_train_batch( packed_tensors: PackedTensors, config: types.TrainConfig, _config: dev.TrainConfig, - inputs_queue: asyncio.Queue[TrainInputs], + inputs_queue: asyncio.Queue[_TrainLoopInput], results_queue: asyncio.Queue[dict[str, float]], train_task: asyncio.Task[None], trainer: "GRPOTrainer", @@ -215,7 +224,7 @@ class UnslothState: tokenizer: PreTrainedTokenizerBase peft_model: peft.peft_model.PeftModelForCausalLM trainer: GRPOTrainer - inputs_queue: asyncio.Queue[TrainInputs] + inputs_queue: asyncio.Queue[_TrainLoopInput] results_queue: asyncio.Queue[dict[str, float]] _is_offloaded: bool = False _pinned_buffers: dict[str, torch.Tensor] | None = None @@ -316,6 +325,7 @@ class UnslothService: _vllm_log_file: Any = field(default=None, repr=False) _vllm_host: str = "127.0.0.1" _vllm_port: int = 0 + _train_task: asyncio.Task[None] | None = field(default=None, init=False, repr=False) @property def is_dedicated(self) -> bool: @@ -326,6 +336,24 @@ def _next_lora_id(self) -> int: self._lora_id_counter += 1 return self._lora_id_counter + async def aclose(self) -> None: + train_task = self._train_task + self._train_task = None + if train_task is None or train_task.done(): + self.close() + return + + # `_state` is a cached_property. Read from __dict__ directly so + # closing does not instantiate trainer state only to stop a task. 
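+        # The sentinel wakes the trainer's blocking queue read: the patched
+        # _prepare_inputs() raises StopTrainingLoop when it dequeues it,
+        # letting trainer.train() unwind cleanly instead of waiting for
+        # another batch.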
+ state = self.__dict__.get("_state") + assert isinstance(state, UnslothState) + state.inputs_queue.put_nowait(_STOP_TRAIN_INPUT) + try: + await asyncio.wait_for(train_task, timeout=_TRAIN_TASK_SHUTDOWN_TIMEOUT_S) + except asyncio.TimeoutError: + train_task.cancel() + self.close() + # ========================================================================= # Dedicated mode: vLLM subprocess lifecycle # ========================================================================= @@ -595,7 +623,7 @@ async def _train_dedicated( await self._state.results_queue.join() - if not hasattr(self, "_train_task") or self._train_task is None: + if self._train_task is None: self._train_task = asyncio.create_task( train( trainer=self._state.trainer, @@ -685,7 +713,7 @@ async def _train_shared( await self._state.results_queue.join() # If we haven't already, start the training task - if not hasattr(self, "_train_task") or self._train_task is None: + if self._train_task is None: self._train_task = asyncio.create_task( train( trainer=self._state.trainer, @@ -981,17 +1009,19 @@ def _state(self) -> UnslothState: trainer.create_optimizer() # Initialize queues - inputs_queue: asyncio.Queue[TrainInputs] = asyncio.Queue() + inputs_queue: asyncio.Queue[_TrainLoopInput] = asyncio.Queue() results_queue: asyncio.Queue[dict[str, float]] = asyncio.Queue() # Patch trainer _prepare_inputs() to pull from queue def _async_prepare_inputs(*_: Any, **__: Any) -> dict[str, torch.Tensor]: - async def get_inputs() -> TrainInputs: + async def get_inputs() -> _TrainLoopInput: return await inputs_queue.get() # Force otherwise synchronous _prepare_inputs() to yield # with nested asyncio.run() call inputs = asyncio.run(get_inputs()) + if isinstance(inputs, _StopTrainInputs): + raise StopTrainingLoop() return cast(dict[str, torch.Tensor], inputs) diff --git a/src/art/unsloth/train.py b/src/art/unsloth/train.py index f6c42a2c..43387904 100644 --- a/src/art/unsloth/train.py +++ b/src/art/unsloth/train.py @@ -39,6 +39,10 @@ } +class StopTrainingLoop(Exception): + """Signal that the background trainer loop should exit cleanly.""" + + def _canonicalize_upstream_metric_key(metric: str) -> str: if "/" in metric: return metric @@ -75,6 +79,8 @@ async def train( trainer._metrics = {"train": defaultdict(list)} try: trainer.train() + except StopTrainingLoop: + return finally: trainer.compute_loss = _compute_loss trainer.log = _log # ty:ignore[invalid-assignment] diff --git a/src/art/utils/benchmarking/pull_model_trajectories.py b/src/art/utils/benchmarking/pull_model_trajectories.py index b8a06bfe..2708d5cf 100644 --- a/src/art/utils/benchmarking/pull_model_trajectories.py +++ b/src/art/utils/benchmarking/pull_model_trajectories.py @@ -31,8 +31,9 @@ async def pull_model_trajectories(model: ArtModel) -> None: "Environment variable BACKUP_BUCKET is required but was not found." ) - # Use the LocalBackend context manager to work with the on-disk artefacts. - with LocalBackend() as backend: + # Use the LocalBackend async context manager so backend cleanup can await + # any background service shutdown before returning. 
+ async with LocalBackend() as backend: print( f"Pulling trajectories for model '{model.name}' from S3 bucket '{bucket}'…", flush=True, diff --git a/tests/integration/test_pipeline_localbackend_dedicated.py b/tests/integration/test_pipeline_localbackend_dedicated.py new file mode 100644 index 00000000..f1154a33 --- /dev/null +++ b/tests/integration/test_pipeline_localbackend_dedicated.py @@ -0,0 +1,206 @@ +"""Dedicated LocalBackend smoke test for PipelineTrainer.""" + +import asyncio +import json +import math +import os +from pathlib import Path +import tempfile +import uuid + +import openai +import pytest + +torch = pytest.importorskip("torch") +pytest.importorskip("vllm") + +import art +from art.local import LocalBackend +from art.pipeline_trainer import PipelineTrainer + +DEFAULT_BASE_MODEL = "Qwen/Qwen3-0.6B" +DEFAULT_GPU_MEMORY_UTILIZATION = 0.2 +DEFAULT_MAX_MODEL_LEN = 2048 +DEFAULT_MAX_SEQ_LENGTH = 2048 + + +def get_base_model() -> str: + return os.environ.get("BASE_MODEL", DEFAULT_BASE_MODEL) + + +def get_safe_gpu_memory_utilization() -> float: + requested = float( + os.environ.get( + "ART_TEST_GPU_MEMORY_UTILIZATION", + str(DEFAULT_GPU_MEMORY_UTILIZATION), + ) + ) + min_free_gib = float(os.environ.get("ART_TEST_MIN_FREE_GPU_GIB", "8")) + free_ratios: list[float] = [] + for device in (0, 1): + free_bytes, total_bytes = torch.cuda.mem_get_info(device) + free_gib = free_bytes / (1024**3) + if free_gib < min_free_gib: + pytest.skip( + "Insufficient free GPU memory for dedicated LocalBackend smoke test: " + f"GPU {device} has {free_gib:.1f} GiB free < {min_free_gib:.1f} GiB required." + ) + free_ratios.append(free_bytes / total_bytes) + return max(0.02, min(requested, min(free_ratios) * 0.8)) + + +def get_dedicated_vllm_test_config() -> art.dev.InternalModelConfig: + return { + "trainer_gpu_ids": [0], + "inference_gpu_ids": [1], + "engine_args": { + "gpu_memory_utilization": get_safe_gpu_memory_utilization(), + "max_model_len": int( + os.environ.get("ART_TEST_MAX_MODEL_LEN", str(DEFAULT_MAX_MODEL_LEN)) + ), + "max_num_seqs": 8, + "enforce_eager": True, + }, + "init_args": { + "max_seq_length": int( + os.environ.get("ART_TEST_MAX_SEQ_LENGTH", str(DEFAULT_MAX_SEQ_LENGTH)) + ), + }, + } + + +def reward_for_answer(text: str) -> float: + content = text.lower() + if "yes" in content: + return 1.0 + if "no" in content: + return 0.5 + if "maybe" in content: + return 0.25 + return 0.0 + + +async def assert_chat_logprobs( + client: openai.AsyncOpenAI, + model_name: str, +) -> None: + completion = await client.chat.completions.create( + messages=[{"role": "user", "content": "Say hello."}], + model=model_name, + max_tokens=8, + timeout=60, + logprobs=True, + top_logprobs=0, + ) + assert completion.choices[0].logprobs is not None + + +@pytest.mark.skipif( + not torch.cuda.is_available() or torch.cuda.device_count() < 2, + reason="Need at least 2 CUDA GPUs for dedicated LocalBackend PipelineTrainer test", +) +async def test_pipeline_trainer_local_backend_dedicated_smoke() -> None: + model_name = f"test-pipeline-local-dedicated-{uuid.uuid4().hex[:8]}" + prompts = [ + "Say yes", + "Say no", + "Say maybe", + "Say hello", + "Say yes again", + "Say no again", + ] + + with tempfile.TemporaryDirectory() as tmpdir: + async with LocalBackend(path=tmpdir) as backend: + model = art.TrainableModel( + name=model_name, + project="integration-tests", + base_model=get_base_model(), + _internal_config=get_dedicated_vllm_test_config(), + ) + + async def scenario_iter(): + for prompt in prompts: + yield {"prompt": 
prompt} + + await model.register(backend) + client = model.openai_client() + try: + + async def rollout_fn( + rollout_model: art.TrainableModel, + scenario: dict[str, str], + _config: None, + ) -> art.TrajectoryGroup: + await asyncio.sleep(0.2) + messages: art.Messages = [ + {"role": "user", "content": scenario["prompt"]} + ] + completion = await client.chat.completions.create( + messages=messages, + model=rollout_model.get_inference_name(), + max_tokens=10, + timeout=60, + temperature=1, + n=2, + logprobs=True, + top_logprobs=0, + ) + return art.TrajectoryGroup( + [ + art.Trajectory( + messages_and_choices=[*messages, choice], + reward=reward_for_answer(choice.message.content or ""), + ) + for choice in completion.choices + ] + ) + + trainer = PipelineTrainer( + model=model, + backend=backend, + rollout_fn=rollout_fn, + scenarios=scenario_iter(), + config=None, + num_rollout_workers=2, + min_batch_size=1, + max_batch_size=1, + max_steps=2, + kl_penalty_coef=0.25, + loss_fn="cispo", + eval_fn=None, + ) + + await trainer.train() + + latest_step = await model.get_step() + assert latest_step >= 2 + + await assert_chat_logprobs(client, model.get_inference_name(step=0)) + await assert_chat_logprobs( + client, model.get_inference_name(step=latest_step) + ) + + model_ids = [m.id async for m in client.models.list()] + assert f"{model.name}@0" in model_ids + assert f"{model.name}@{latest_step}" in model_ids + + history_path = ( + Path(tmpdir) + / model.project + / "models" + / model.name + / "history.jsonl" + ) + history_rows = [ + json.loads(line) for line in history_path.read_text().splitlines() + ] + kl_values = [ + row["loss/kl_policy_ref"] + for row in history_rows + if "loss/kl_policy_ref" in row + ] + assert kl_values + assert all(math.isfinite(value) for value in kl_values) + finally: + await client.close() diff --git a/tests/integration/test_tinker_native_backend.py b/tests/integration/test_tinker_native_backend.py index 09ff33c4..4e5c61a5 100644 --- a/tests/integration/test_tinker_native_backend.py +++ b/tests/integration/test_tinker_native_backend.py @@ -9,6 +9,8 @@ import art from art.tinker_native import TinkerNativeBackend +from art.tinker_native.backend import _apply_kl_penalty +from art.tinker_native.data import trajectory_groups_to_datums DEFAULT_BASE_MODEL = "Qwen/Qwen3-30B-A3B-Instruct-2507" @@ -37,6 +39,8 @@ async def simple_rollout( max_tokens=10, timeout=60, temperature=1, + logprobs=True, + top_logprobs=0, ) choice = chat_completion.choices[0] content = (choice.message.content or "").lower() @@ -115,6 +119,85 @@ async def make_group(prompt: str) -> art.TrajectoryGroup: await backend.close() +@pytest.mark.skipif( + "TINKER_API_KEY" not in os.environ, + reason="TINKER_API_KEY not set - skipping TinkerNativeBackend KL test", +) +async def test_tinker_native_backend_kl_identity_metric(): + model_name = f"test-tinker-native-kl-{uuid.uuid4().hex[:8]}" + with tempfile.TemporaryDirectory() as tmpdir: + backend = TinkerNativeBackend(path=tmpdir) + model = art.TrainableModel( + name=model_name, + project="integration-tests", + base_model=get_base_model(), + ) + try: + await model.register(backend) + + openai_client = model.openai_client() + current_step = await model.get_step() + model_name_step = model.get_inference_name(step=current_step) + prompts = ["Say yes", "Say no", "Say maybe"] + + async def make_group(prompt: str) -> art.TrajectoryGroup: + import asyncio + + trajectories = await asyncio.gather( + *[ + simple_rollout(openai_client, model_name_step, prompt) + for _ in range(2) 
+ ] + ) + return art.TrajectoryGroup(trajectories) # type: ignore[attr-defined] + + train_groups = await art.gather_trajectory_groups( # type: ignore[attr-defined] + [make_group(prompt) for prompt in prompts] + ) + ensure_reward_variance(train_groups) + + state = backend._model_state[model.name] + datums = trajectory_groups_to_datums( + train_groups, + state.renderer, + state.tokenizer, + ) + assert datums + + reference_sampling_client = await backend._get_kl_reference_sampling_client( + state, + model.base_model, + current_step, + ) + expected_kl = ( + await _apply_kl_penalty( + trajectory_groups_to_datums( + train_groups, + state.renderer, + state.tokenizer, + ), + reference_sampling_client, + kl_penalty_coef=0.25, + ) + )["loss/kl_policy_ref"] + + result = await backend.train( + model, + train_groups, + learning_rate=1e-5, + kl_penalty_coef=0.25, + kl_penalty_reference_step=current_step, + ) + + assert result.metrics["loss/kl_policy_ref"] == pytest.approx( + expected_kl, + abs=0.05, + ) + assert result.metrics["loss/kl_policy_ref"] == pytest.approx(0.0, abs=0.05) + finally: + await backend.close() + + @pytest.mark.skipif( "TINKER_API_KEY" not in os.environ, reason="TINKER_API_KEY not set - skipping TinkerNativeBackend fork test", diff --git a/tests/unit/test_pipeline_trainer_batching.py b/tests/unit/test_pipeline_trainer_batching.py new file mode 100644 index 00000000..0ab412e8 --- /dev/null +++ b/tests/unit/test_pipeline_trainer_batching.py @@ -0,0 +1,67 @@ +import asyncio +from pathlib import Path +from unittest.mock import MagicMock + +import pytest + +from art import TrainableModel, Trajectory, TrajectoryGroup +from art.pipeline_trainer.trainer import PipelineTrainer + + +def _make_group() -> TrajectoryGroup: + return TrajectoryGroup( + [ + Trajectory( + reward=reward, + initial_policy_version=0, + messages_and_choices=[ + {"role": "user", "content": f"prompt-{idx}"}, + {"role": "assistant", "content": f"answer-{idx}"}, + ], + ) + for idx, reward in enumerate([0.0, 1.0]) + ] + ) + + +@pytest.mark.asyncio +async def test_collect_batch_respects_max_batch_size(tmp_path: Path) -> None: + model = TrainableModel( + name="pipeline-max-batch-size-test", + project="pipeline-tests", + base_model="test-model", + base_path=str(tmp_path), + ) + trainer = PipelineTrainer( + model=model, + backend=MagicMock(), # type: ignore[arg-type] + rollout_fn=lambda *_args, **_kwargs: asyncio.sleep(0), + scenarios=[], + config={}, + num_rollout_workers=1, + min_batch_size=1, + max_batch_size=2, + max_steps=1, + eval_fn=None, + ) + trainer._output_queue = asyncio.Queue() + + first = _make_group() + second = _make_group() + third = _make_group() + await trainer._output_queue.put(first) + await trainer._output_queue.put(second) + await trainer._output_queue.put(third) + await trainer._output_queue.put(None) + + batch, discarded, saw_sentinel = await trainer._collect_batch(current_step=0) + + assert batch == [first, second] + assert discarded == 0 + assert not saw_sentinel + + batch, discarded, saw_sentinel = await trainer._collect_batch(current_step=0) + + assert batch == [third] + assert discarded == 0 + assert saw_sentinel diff --git a/tests/unit/test_pipeline_trainer_local_backend.py b/tests/unit/test_pipeline_trainer_local_backend.py new file mode 100644 index 00000000..d17af12d --- /dev/null +++ b/tests/unit/test_pipeline_trainer_local_backend.py @@ -0,0 +1,398 @@ +import asyncio +from pathlib import Path +from types import SimpleNamespace +from typing import Any, cast +from unittest.mock import AsyncMock, 
MagicMock, patch + +import pytest + +from art import TrainableModel, Trajectory, TrajectoryGroup +from art.dev.model import InternalModelConfig +from art.local import LocalBackend +from art.pipeline_trainer.trainer import PipelineTrainer +from art.utils.output_dirs import get_model_dir + + +def _make_group(rewards: list[float]) -> TrajectoryGroup: + return TrajectoryGroup( + [ + Trajectory( + reward=reward, + initial_policy_version=0, + messages_and_choices=[ + {"role": "user", "content": f"prompt-{idx}"}, + {"role": "assistant", "content": f"answer-{idx}"}, + ], + ) + for idx, reward in enumerate(rewards) + ] + ) + + +def _make_trainer( + *, + model: TrainableModel, + backend: object, + **kwargs: Any, +) -> PipelineTrainer: + return PipelineTrainer( + model=model, + backend=backend, # type: ignore[arg-type] + rollout_fn=lambda *_args, **_kwargs: asyncio.sleep(0), + scenarios=[], + config={}, + num_rollout_workers=1, + min_batch_size=1, + max_batch_size=1, + max_steps=1, + eval_fn=None, + **kwargs, + ) + + +@pytest.mark.asyncio +async def test_pipeline_trainer_preserves_backend_train_kwargs(tmp_path: Path) -> None: + model = TrainableModel( + name="pipeline-default-backend-kwargs", + project="pipeline-tests", + base_model="test-model", + base_path=str(tmp_path), + ) + backend = MagicMock() + backend.train = AsyncMock(return_value=SimpleNamespace(step=1, metrics={})) + loss_fn_config = {"alpha": 0.1} + adam_params = object() + + trainer = _make_trainer( + model=model, + backend=backend, + learning_rate=2e-5, + loss_fn="cispo", + loss_fn_config=loss_fn_config, + normalize_advantages=True, + adam_params=adam_params, + ) + trainer._output_queue = asyncio.Queue() + await trainer._output_queue.put(_make_group([0.0, 1.0])) + await trainer._output_queue.put(None) + + await trainer._training_stage() + + assert backend.train.await_args.kwargs == { + "learning_rate": 2e-5, + "loss_fn": "cispo", + "loss_fn_config": loss_fn_config, + "normalize_advantages": True, + "save_checkpoint": False, + "adam_params": adam_params, + } + + +@pytest.mark.asyncio +async def test_pipeline_trainer_forwards_kl_kwargs_for_generic_backend( + tmp_path: Path, +) -> None: + model = TrainableModel( + name="pipeline-generic-backend-kl-kwargs", + project="pipeline-tests", + base_model="test-model", + base_path=str(tmp_path), + ) + backend = MagicMock() + backend.train = AsyncMock(return_value=SimpleNamespace(step=1, metrics={})) + + trainer = _make_trainer( + model=model, + backend=backend, + kl_penalty_coef=0.25, + ) + trainer._output_queue = asyncio.Queue() + await trainer._output_queue.put(_make_group([0.0, 1.0])) + await trainer._output_queue.put(None) + + await trainer._training_stage() + + assert backend.train.await_args.kwargs == { + "learning_rate": 1e-5, + "loss_fn": "cispo", + "loss_fn_config": None, + "normalize_advantages": True, + "save_checkpoint": False, + "adam_params": None, + "kl_penalty_coef": 0.25, + "kl_penalty_reference_step": 0, + "kl_penalty_source": "sample", + } + + +@pytest.mark.asyncio +async def test_pipeline_trainer_kl_step_lag_floors_at_zero( + tmp_path: Path, +) -> None: + """kl_penalty_step_lag floors at step 0 when lag > current_step.""" + model = TrainableModel( + name="pipeline-kl-step-lag-floor", + project="pipeline-tests", + base_model="test-model", + base_path=str(tmp_path), + ) + backend = MagicMock() + backend.train = AsyncMock(return_value=SimpleNamespace(step=2, metrics={})) + + trainer = _make_trainer( + model=model, + backend=backend, + kl_penalty_coef=0.25, + kl_penalty_step_lag=5, 
+ ) + trainer._output_queue = asyncio.Queue() + await trainer._output_queue.put(_make_group([0.0, 1.0])) + await trainer._output_queue.put(None) + + # Simulate being at step 1 + trainer.state.next_training_step = 1 + + await trainer._training_stage() + + # At step 1, lag=5 → reference = max(0, 1-5) = 0 + assert backend.train.await_args.kwargs["kl_penalty_reference_step"] == 0 + + +@pytest.mark.asyncio +async def test_pipeline_trainer_kl_step_lag_computes_reference( + tmp_path: Path, +) -> None: + """kl_penalty_step_lag computes reference from current step.""" + model = TrainableModel( + name="pipeline-kl-step-lag", + project="pipeline-tests", + base_model="test-model", + base_path=str(tmp_path), + ) + backend = MagicMock() + backend.train = AsyncMock(return_value=SimpleNamespace(step=4, metrics={})) + + trainer = _make_trainer( + model=model, + backend=backend, + kl_penalty_coef=0.25, + kl_penalty_step_lag=2, + ) + trainer._output_queue = asyncio.Queue() + await trainer._output_queue.put(_make_group([0.0, 1.0])) + await trainer._output_queue.put(None) + + # Simulate being at step 3 + trainer.state.next_training_step = 3 + + await trainer._training_stage() + + # At step 3, lag=2 → reference = max(0, 3-2) = 1 + assert backend.train.await_args.kwargs["kl_penalty_reference_step"] == 1 + + +@pytest.mark.asyncio +async def test_pipeline_trainer_uses_same_train_kwargs_for_local_backend( + tmp_path: Path, +) -> None: + model = TrainableModel( + name="pipeline-local-backend-kwargs", + project="pipeline-tests", + base_model="test-model", + base_path=str(tmp_path), + _internal_config=InternalModelConfig( + trainer_gpu_ids=[0], + inference_gpu_ids=[1], + ), + ) + backend = LocalBackend(path=str(tmp_path)) + backend.train = AsyncMock(return_value=SimpleNamespace(step=1, metrics={})) # type: ignore[method-assign] + + trainer = _make_trainer( + model=model, + backend=backend, + learning_rate=3e-5, + loss_fn="ppo", + ) + trainer._output_queue = asyncio.Queue() + await trainer._output_queue.put(_make_group([0.0, 1.0])) + await trainer._output_queue.put(None) + + await trainer._training_stage() + + assert backend.train.await_args.kwargs == { # type: ignore[attr-defined] + "learning_rate": 3e-5, + "loss_fn": "ppo", + "loss_fn_config": None, + "normalize_advantages": True, + "save_checkpoint": False, + "adam_params": None, + } + + +@pytest.mark.asyncio +async def test_local_backend_train_translates_loss_fn(tmp_path: Path) -> None: + model = TrainableModel( + name="local-backend-train-translation", + project="pipeline-tests", + base_model="test-model", + base_path=str(tmp_path), + ) + backend = LocalBackend(path=str(tmp_path)) + seen: dict[str, Any] = {} + + async def fake_train_model( + _model: TrainableModel, + _groups: list[TrajectoryGroup], + config: Any, + dev_config: dict[str, Any], + verbose: bool = False, + ): + seen["config"] = config + seen["dev_config"] = dev_config + seen["verbose"] = verbose + yield {} + + backend._train_model = fake_train_model # type: ignore[method-assign] + backend._get_step = AsyncMock(return_value=1) # type: ignore[method-assign] + with patch.object(model, "_get_wandb_run", return_value=None): + result = await backend.train( + model, + [_make_group([1.0])], + loss_fn="ppo", + save_checkpoint=False, + ) + + assert result.step == 1 + assert seen["config"].learning_rate == 5e-6 + assert seen["dev_config"]["ppo"] is True + + +@pytest.mark.asyncio +async def test_local_backend_train_passes_kl_penalty_source(tmp_path: Path) -> None: + model = TrainableModel( + 
name="local-backend-kl-source", + project="pipeline-tests", + base_model="test-model", + base_path=str(tmp_path), + ) + backend = LocalBackend(path=str(tmp_path)) + seen: dict[str, Any] = {} + + async def fake_train_model( + _model: TrainableModel, + _groups: list[TrajectoryGroup], + config: Any, + dev_config: dict[str, Any], + verbose: bool = False, + ): + seen["config"] = config + seen["dev_config"] = dev_config + seen["verbose"] = verbose + yield {} + + backend._train_model = fake_train_model # type: ignore[method-assign] + backend._get_step = AsyncMock(return_value=1) # type: ignore[method-assign] + with patch.object(model, "_get_wandb_run", return_value=None): + result = await backend.train( + model, + [_make_group([1.0])], + kl_penalty_coef=0.25, + kl_penalty_source="sample", + save_checkpoint=False, + ) + + assert result.step == 1 + assert seen["config"].kl_penalty_source == "sample" + assert seen["dev_config"]["kl_penalty_source"] == "sample" + + +@pytest.mark.asyncio +async def test_local_backend_async_context_manager_awaits_async_cleanup( + tmp_path: Path, +) -> None: + backend = LocalBackend(path=str(tmp_path)) + calls: list[str] = [] + + class FakeService: + async def aclose(self) -> None: + calls.append("aclose") + + service = FakeService() + backend._services["test-service"] = cast(Any, service) + + with patch("art.local.backend.close_proxy") as close_proxy: + async with backend: + pass + + assert calls == ["aclose"] + close_proxy.assert_called_once_with(service) + + +@pytest.mark.parametrize( + ("trainer_kwargs", "match"), + [ + ({"loss_fn": "dro"}, "loss_fn='cispo' or loss_fn='ppo'"), + ({"loss_fn_config": {"clip": 0.2}}, "loss_fn_config=None"), + ({"normalize_advantages": False}, "normalize_advantages=True"), + ({"adam_params": object()}, "adam_params=None"), + ], +) +def test_pipeline_trainer_rejects_unsupported_local_backend_settings( + tmp_path: Path, + trainer_kwargs: dict[str, object], + match: str, +) -> None: + model = TrainableModel( + name="pipeline-local-backend-invalid", + project="pipeline-tests", + base_model="test-model", + base_path=str(tmp_path), + _internal_config=InternalModelConfig( + trainer_gpu_ids=[0], + inference_gpu_ids=[1], + ), + ) + + with pytest.raises(ValueError, match=match): + _make_trainer( + model=model, + backend=LocalBackend(path=str(tmp_path)), + **trainer_kwargs, + ) + + +def test_pipeline_trainer_rejects_shared_local_backend(tmp_path: Path) -> None: + model = TrainableModel( + name="pipeline-local-backend-shared", + project="pipeline-tests", + base_model="test-model", + base_path=str(tmp_path), + ) + + with pytest.raises( + ValueError, match="only supports LocalBackend in dedicated mode" + ): + _make_trainer(model=model, backend=LocalBackend(path=str(tmp_path))) + + +def test_local_backend_inference_name_prefers_served_step_in_dedicated_mode( + tmp_path: Path, +) -> None: + model = TrainableModel( + name="local-backend-served-step", + project="pipeline-tests", + base_model="test-model", + base_path=str(tmp_path), + _internal_config=InternalModelConfig( + trainer_gpu_ids=[0], + inference_gpu_ids=[1], + ), + ) + backend = LocalBackend(path=str(tmp_path)) + output_dir = Path(get_model_dir(model=model, art_path=str(tmp_path))) + (output_dir / "checkpoints" / "3").mkdir(parents=True) + backend._services[model.name] = cast(Any, SimpleNamespace(_latest_step=2)) + + assert backend._model_inference_name(model) == f"{model.name}@2" + assert backend._model_inference_name(model, step=3) == f"{model.name}@3" diff --git 
a/tests/unit/test_tinker_native_kl.py b/tests/unit/test_tinker_native_kl.py new file mode 100644 index 00000000..a2d16d01 --- /dev/null +++ b/tests/unit/test_tinker_native_kl.py @@ -0,0 +1,77 @@ +import pytest +import tinker + +from art import TrainableModel +from art.tinker_native.backend import TinkerNativeBackend, _apply_kl_penalty +from art.tinker_native.data import build_datum + + +class FakeSamplingClient(tinker.SamplingClient): + def __init__(self, responses: dict[tuple[int, ...], list[float | None]]) -> None: + self._responses = responses + + async def compute_logprobs_async( + self, prompt: tinker.ModelInput + ) -> list[float | None]: + return self._responses[tuple(prompt.to_ints())] + + +@pytest.mark.asyncio +async def test_incorporate_kl_penalty_rewrites_advantages_in_place() -> None: + datum_a = build_datum( + prompt_tokens=[101, 102], + completion_tokens=[201, 202], + logprobs=[-0.4, -0.8], + advantage=1.0, + ) + datum_b = build_datum( + prompt_tokens=[301, 302], + completion_tokens=[401], + logprobs=[-0.2], + advantage=2.0, + ) + assert datum_a is not None + assert datum_b is not None + + sampling_client = FakeSamplingClient( + { + (101, 102, 201, 202): [None, -9.0, -0.1, -0.5], + (301, 302, 401): [None, -7.0, -0.05], + } + ) + + metrics = await _apply_kl_penalty( + [datum_a, datum_b], + sampling_client, + kl_penalty_coef=2.0, + ) + + assert metrics == {"loss/kl_policy_ref": pytest.approx(-0.25)} + assert datum_a.loss_fn_inputs["advantages"].tolist() == pytest.approx( + [0.0, 1.1, 1.1] + ) + assert datum_b.loss_fn_inputs["advantages"].tolist() == pytest.approx([0.0, 1.8]) + + +@pytest.mark.asyncio +async def test_tinker_native_backend_rejects_current_learner_kl_source( + tmp_path, +) -> None: + backend = TinkerNativeBackend(tinker_api_key="test-key", path=str(tmp_path)) + model = TrainableModel( + name="tinker-native-kl-source", + project="pipeline-tests", + base_model="test-model", + base_path=str(tmp_path), + ) + + with pytest.raises( + AssertionError, + match="only supports kl_penalty_source='sample'", + ): + await backend.train( + model, + [], + kl_penalty_coef=0.25, + kl_penalty_source="current_learner", # ty:ignore[invalid-argument-type] + )