Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -175,3 +175,5 @@ werewolves_swarm
tensorboard_log
tutorial/**/*.json
node_modules
.agents
skills-lock.json
66 changes: 9 additions & 57 deletions ajet/backbone/main_verl.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ def run_ppo(config: DictConfig) -> None:
runtime_env = get_runtime_env(config)
ray.init(
runtime_env=runtime_env,
num_cpus=config.ray_init.num_cpus,
)

def on_shutdown():
Expand Down Expand Up @@ -93,12 +92,6 @@ def on_shutdown():
runner = TaskRunner.remote()
ray.get(runner.run.remote(config))

# [Optional] get the path of the timeline trace file from the configuration, default to None
# This file is used for performance analysis
timeline_json_file = config.ray_init.get("timeline_json_file", None)
if timeline_json_file:
ray.timeline(filename=timeline_json_file)


@ray.remote(num_cpus=1) # please make sure main_task is not scheduled on head
class TaskRunner:
Expand Down Expand Up @@ -148,35 +141,25 @@ def run(self, config):
if config.actor_rollout_ref.actor.strategy in {"fsdp", "fsdp2"}:
assert config.critic.strategy in {"fsdp", "fsdp2"}
from verl.single_controller.ray import RayWorkerGroup
from verl.workers.fsdp_workers import (
ActorRolloutRefWorker,
AsyncActorRolloutRefWorker,
)
from ajet.backbone.verl import AjetActorRolloutRefWorker
from ajet.backbone.verl import AjetAsyncActorRolloutRefWorker


use_legacy_worker_impl = config.trainer.get("use_legacy_worker_impl", "auto")
if use_legacy_worker_impl in ["auto", "enable"]:
# import warnings
# warnings.warn(f"Legacy worker impl is going to be deprecated, will be removed in the future. \
# Please set trainer.use_legacy_worker_impl = false to switch to the new worker implementation.")
from verl.workers.fsdp_workers import CriticWorker
elif use_legacy_worker_impl == "disable":
from verl.workers.roles import CriticWorker
else:
raise ValueError(f"Invalid use_legacy_worker_impl: {use_legacy_worker_impl}")

actor_rollout_cls = AsyncActorRolloutRefWorker

ActorRolloutRefWorker = AjetActorRolloutRefWorker
actor_rollout_cls = AjetAsyncActorRolloutRefWorker
ray_worker_group_cls = RayWorkerGroup

elif config.actor_rollout_ref.actor.strategy == "megatron":
assert config.actor_rollout_ref.actor.strategy == config.critic.strategy
from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup
from verl.workers.megatron_workers import (
ActorRolloutRefWorker,
AsyncActorRolloutRefWorker,
AjetAsyncActorRolloutRefWorker,
CriticWorker,
)

actor_rollout_cls = AsyncActorRolloutRefWorker
actor_rollout_cls = AjetAsyncActorRolloutRefWorker
ray_worker_group_cls = NVMegatronRayWorkerGroup

else:
Expand All @@ -187,7 +170,6 @@ def run(self, config):
# Map roles to their corresponding remote worker classes.
role_worker_mapping = {
Role.ActorRollout: ray.remote(actor_rollout_cls),
Role.Critic: ray.remote(CriticWorker),
}

# Define the resource pool specification.
Expand All @@ -198,43 +180,15 @@ def run(self, config):
}
mapping = {
Role.ActorRollout: global_pool_id,
Role.Critic: global_pool_id,
}

# We should adopt a multi-source reward function here:
# - for rule-based rm, we directly call a reward score
# - for model-based rm, we call a model
# - for code related prompt, we send to a sandbox if there are test cases
# finally, we combine all the rewards together
# The reward type depends on the tag of the data
if config.reward_model.enable:
if config.reward_model.strategy in {"fsdp", "fsdp2"}:
from verl.workers.fsdp_workers import RewardModelWorker
elif config.reward_model.strategy == "megatron":
from verl.workers.megatron_workers import RewardModelWorker
else:
raise NotImplementedError
role_worker_mapping[Role.RewardModel] = ray.remote(RewardModelWorker)
mapping[Role.RewardModel] = global_pool_id

# Add a reference policy worker if KL loss or KL reward is used.
if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss:
role_worker_mapping[Role.RefPolicy] = ray.remote(ActorRolloutRefWorker)
mapping[Role.RefPolicy] = global_pool_id

# Load the reward manager for training and validation.
reward_fn = load_reward_manager(
config,
tokenizer,
num_examine=0,
**config.reward_model.get("reward_kwargs", {}),
)
val_reward_fn = load_reward_manager(
config,
tokenizer,
num_examine=1,
**config.reward_model.get("reward_kwargs", {}),
)

resource_pool_manager = ResourcePoolManager(
resource_pool_spec=resource_pool_spec, mapping=mapping
)
Expand Down Expand Up @@ -262,8 +216,6 @@ def run(self, config):
role_worker_mapping=role_worker_mapping,
resource_pool_manager=resource_pool_manager,
ray_worker_group_cls=ray_worker_group_cls,
reward_fn=reward_fn,
val_reward_fn=val_reward_fn,
train_dataset=train_dataset,
val_dataset=val_dataset,
collate_fn=collate_fn,
Expand Down
Loading
Loading