DAPO代码实现浅析

参考verl对DAPO的实现,首先咱们看一下入口的.sh和.py文件。./recipe/dapo/文件夹中的目录结构如下:

.
├── config
│   ├── dapo_megatron_trainer.yaml
│   └── dapo_trainer.yaml
├── dapo_ray_trainer.py
├── main_dapo.py
├── prepare_dapo_data.sh
├── README.md
├── run_dapo_qwen2.5_32b.sh

整体的执行顺序:

  • main_dapo.py:数据加载初始化、初始化actor_rollout model、rm model,加载reward_manager
  • dapo_ray_trainer.py:RL训练流程
    • 对batch进行repeat,每个q采样n次
    • 记录每个采样的log prob(old_log_prob),以及对应的reward score和advantage
      • 过滤掉同一个q的所有sample得分完全相同(reward的std为0,例如全部答对或全部答错)的q,然后继续获取新的q进行采样,直到满足要求的q的数量达到train_prompt_bsz。(值得注意的是,每次生成的batch大小是gen_prompt_bsz=3*train_prompt_bsz,通过提高采样q的个数,避免满足要求的q不足train_prompt_bsz;累积超出的部分会被截断丢弃。)
    • 每mini_batch的data进行模型更新
      • 每micro_batch的data进行前向传播(token-mean loss)与梯度计算

具体代码实例:

main_dapo.py

# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Note that we don't combine the main with ray_trainer as ray_trainer is used by other main.
"""

import os
import socket

import hydra
import ray
from omegaconf import OmegaConf

from verl.trainer.ppo.reward import load_reward_manager
from verl.utils.device import is_cuda_available

from .dapo_ray_trainer import RayDAPOTrainer


@hydra.main(config_path="config", config_name="dapo_trainer", version_base=None)
def main(config):
    """Hydra entry point: receive the composed ``dapo_trainer`` config and start training."""
    run_ppo(config)


#################################################################
# RL training entry point
#################################################################
def run_ppo(config) -> None:
    """Initialize Ray if needed, run a remote ``TaskRunner`` to completion, then shut Ray down.

    Args:
        config: The composed OmegaConf trainer config.
    """
    if not ray.is_initialized():
        # this is for local ray cluster
        # The defaults below are merged with any user-supplied
        # ``ray_kwargs.ray_init.runtime_env`` (later values win in OmegaConf.merge).
        default_runtime_env = {
            "env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN", "VLLM_LOGGING_LEVEL": "WARN"}
        }
        ray_init_kwargs = config.ray_kwargs.get("ray_init", {})
        runtime_env_kwargs = ray_init_kwargs.get("runtime_env", {})
        runtime_env = OmegaConf.merge(default_runtime_env, runtime_env_kwargs)
        ray_init_kwargs = OmegaConf.create({**ray_init_kwargs, "runtime_env": runtime_env})
        print(f"ray init kwargs: {ray_init_kwargs}")
        ray.init(**OmegaConf.to_container(ray_init_kwargs))

    try:
        # Launch the TaskRunner under Nsight Systems only when CUDA is available,
        # the profiler tool is "nsys", and a non-empty list of steps is configured.
        if (
            is_cuda_available
            and config.global_profiler.tool == "nsys"
            and OmegaConf.select(config.global_profiler, "steps") is not None
            and len(OmegaConf.select(config.global_profiler, "steps")) > 0
        ):
            nsight_options = OmegaConf.to_container(
                config.global_profiler.global_tool_config.nsys.controller_nsight_options
            )
            runner = TaskRunner.options(runtime_env={"nsight": nsight_options}).remote()
        else:
            runner = TaskRunner.remote()
        # Block until the remote training run finishes.
        ray.get(runner.run.remote(config))
    finally:
        # NOTE(review): this shuts Ray down even if the caller, not this
        # function, initialized it.
        if ray.is_initialized():
            ray.shutdown()


@ray.remote(num_cpus=1)  # please make sure main_task is not scheduled on head
class TaskRunner:
    """Ray actor that wires up tokenizer, workers, resource pools and reward managers, then runs DAPO training."""

    def run(self, config):
        """Build all training components from ``config`` and call ``RayDAPOTrainer.fit()``.

        Args:
            config: The composed trainer config; it is resolved in place
                (``OmegaConf.resolve``) before use.

        Raises:
            NotImplementedError: If the actor or reward-model strategy is not
                one of fsdp/fsdp2/megatron.
        """
        # print initial config
        from pprint import pprint

        from omegaconf import OmegaConf

        from verl.utils.fs import copy_to_local

        print(f"TaskRunner hostname: {socket.gethostname()}, PID: {os.getpid()}")

        pprint(OmegaConf.to_container(config, resolve=True))  # resolve=True will eval symbol values
        OmegaConf.resolve(config)

        # download the checkpoint from hdfs
        local_path = copy_to_local(config.actor_rollout_ref.model.path)

        # instantiate tokenizer
        from verl.utils import hf_processor, hf_tokenizer

        tokenizer = hf_tokenizer(local_path)
        processor = hf_processor(local_path, use_fast=True)  # used for multimodal LLM, could be None

        from verl.single_controller.ray import RayWorkerGroup

        #################################################################
        # Select the actor/rollout/ref and critic worker classes
        # for the configured training strategy.
        #################################################################
        # define worker classes
        if config.actor_rollout_ref.actor.strategy in {"fsdp", "fsdp2"}:
            assert config.critic.strategy in {"fsdp", "fsdp2"}

            from verl.workers.fsdp_workers import ActorRolloutRefWorker, CriticWorker

            ray_worker_group_cls = RayWorkerGroup

        elif config.actor_rollout_ref.actor.strategy == "megatron":
            assert config.actor_rollout_ref.actor.strategy == config.critic.strategy
            from verl.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker

            ray_worker_group_cls = RayWorkerGroup

        else:
            raise NotImplementedError

        from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role

        role_worker_mapping = {
            Role.ActorRollout: ray.remote(ActorRolloutRefWorker),
            Role.Critic: ray.remote(CriticWorker),
        }

        # Every role registered here is placed in one shared resource pool that
        # spans n_gpus_per_node GPUs on each of nnodes nodes.
        global_pool_id = "global_pool"
        resource_pool_spec = {
            global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes,
        }
        mapping = {
            Role.ActorRollout: global_pool_id,
            Role.Critic: global_pool_id,
        }

        # we should adopt a multi-source reward function here
        # - for rule-based rm, we directly call a reward score
        # - for model-based rm, we call a model
        # - for code related prompt, we send to a sandbox if there are test cases
        # - finally, we combine all the rewards together
        # - The reward type depends on the tag of the data
        if config.reward_model.enable:
            if config.reward_model.strategy in {"fsdp", "fsdp2"}:
                from verl.workers.fsdp_workers import RewardModelWorker
            elif config.reward_model.strategy == "megatron":
                from verl.workers.megatron_workers import RewardModelWorker
            else:
                raise NotImplementedError
            role_worker_mapping[Role.RewardModel] = ray.remote(RewardModelWorker)
            mapping[Role.RewardModel] = global_pool_id

        # reference model
        # A reference-policy worker is only registered when a KL term is used,
        # either in the reward or in the actor loss.
        if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss:
            role_worker_mapping[Role.RefPolicy] = ray.remote(ActorRolloutRefWorker)
            mapping[Role.RefPolicy] = global_pool_id

        #################################################################
        # Build the reward managers that compute reward scores from
        # batch data (training and validation).
        #################################################################
        reward_fn = load_reward_manager(
            config,
            tokenizer,
            0,  # num_examine for training
            max_resp_len=config.data.max_response_length,
            overlong_buffer_cfg=config.reward_model.overlong_buffer,
        )

        # Note that we always use function-based RM for validation
        val_reward_fn = load_reward_manager(
            config,
            tokenizer,
            1,  # num_examine for validation
            max_resp_len=config.data.max_response_length,
            overlong_buffer_cfg=config.reward_model.overlong_buffer,
        )
        resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)

        #################################################################
        # Build the main DAPO RL trainer and run .fit()
        #################################################################
        trainer = RayDAPOTrainer(
            config=config,
            tokenizer=tokenizer,
            processor=processor,
            role_worker_mapping=role_worker_mapping,
            resource_pool_manager=resource_pool_manager,
            ray_worker_group_cls=ray_worker_group_cls,
            reward_fn=reward_fn,
            val_reward_fn=val_reward_fn,
        )
        trainer.init_workers()
        trainer.fit()


if __name__ == "__main__":
    main()

我们紧接着来看一下 from verl.trainer.ppo.reward import load_reward_manager 中的 load_reward_manager 函数。
启动脚本 verl/recipe/dapo/run_dapo_qwen2.5_32b.sh 中给出了 reward manager 的类型以及 overlong buffer 的相关配置:

enable_overlong_buffer=True overlong_buffer_len=$((1024 * 4)) # overlong soft overlong_penalty_factor=1.0  reward_model.reward_manager=dapo  reward_model.overlong_buffer.enable=${enable_overlong_buffer}  reward_model.overlong_buffer.len=${overlong_buffer_len}  reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor}  

verl.trainer.ppo.reward.py

def load_reward_manager(
    config: DictConfig, tokenizer: Any, num_examine: int, **reward_kwargs: Any
) -> AbstractRewardManager:
    """
    Load and initialize a reward manager based on the configuration.

    The manager class is looked up by ``config.reward_model.reward_manager``
    (default ``"naive"``). The scoring function is a user-registered custom
    reward function if one is configured. Otherwise it is
    ``default_compute_score``. When ``reward_model.sandbox_fusion.url`` is
    set, that default is bound to the sandbox URL, its memory limit, and a
    shared semaphore that caps concurrent sandbox requests.

    Args:
        config: PPO trainer configuration object containing reward_model fields.
        tokenizer: Tokenizer object used for processing text.
        num_examine: Number of samples to examine.
        **reward_kwargs: Additional keyword arguments for the reward manager.

    Returns:
        An instance of the specified reward manager class.
    """

    # Try to get a custom reward function based on the configuration
    # user defined reward manager can be registered in custom_reward_fn
    compute_score = get_custom_reward_fn(config)
    final_compute_score = compute_score

    # The list of pre-defined reward managers are defined in `verl/workers/reward_manager/`:
    # naive: NaiveRewardManager
    # prime: PrimeRewardManager
    # batch: BatchRewardManager
    # dapo: DAPORewardManager
    # Note(haibin.lin): For custom reward managers, please make sure they are imported and
    # registered via `verl.workers.reward_manager.register`
    # By default reward_manager is set to naive (NaiveRewardManager)
    #################################################################
    # Resolve the concrete reward manager class by name here.
    #################################################################
    reward_manager_name = config.reward_model.get("reward_manager", "naive")
    reward_manager_cls = get_reward_manager_cls(reward_manager_name)

    if compute_score is None:
        sandbox_config = config.reward_model.get("sandbox_fusion")
        sandbox_url = sandbox_config.get("url") if sandbox_config else None
        if sandbox_url:
            # sandbox_config may be None (sandbox_url is then None as well),
            # so its other fields are only read inside this branch.
            memory_limit_mb = sandbox_config.get("memory_limit_mb", 1024)
            sandbox_manager = multiprocessing.Manager()
            # Create a semaphore to control concurrent access to the sandbox
            _concurrent_semaphore = sandbox_manager.Semaphore(sandbox_config.get("max_concurrent", 64))
            final_compute_score = partial(
                default_compute_score,
                sandbox_fusion_url=sandbox_url,
                concurrent_semaphore=_concurrent_semaphore,
                memory_limit_mb=memory_limit_mb,
            )
        else:
            final_compute_score = default_compute_score

    #################################################################
    # With reward_model.reward_manager=dapo (as in the DAPO recipe
    # script), reward_manager_cls resolves to DAPORewardManager.
    #################################################################
    # Instantiate and return the reward manager with the specified parameters
    return reward_manager_cls(
        tokenizer=tokenizer,
        num_examine=num_examine,
        compute_score=final_compute_score,
        reward_fn_key=config.data.reward_fn_key,
        **reward_kwargs,
    )

这里需要知道DAPO的reward_manager_cls具体是什么。由于reward需要batch数据才能计算,对于reward manager咱们先按下不表(其实DAPO对应的reward_manager_cls是verl/verl/workers/reward_manager/dapo.py中的DAPORewardManager),先去dapo_ray_trainer.py看一下batch是怎么采样的,再回来仔细阅读reward的具体计算方法。

dapo_ray_trainer.py

################################################################# # RayDAPOTrainer继承于RayPPOTrainer # fit()函数:执行dapo的训练,包括(1)动态采样(2)overlong soft reward计算(3)token-level loss  ################################################################# class RayDAPOTrainer(RayPPOTrainer):     """     Note that this trainer runs on the driver process on a single CPU/GPU node.     """      def fit(self):         """         The training loop of PPO.         The driver process only need to call the compute functions of the worker group through RPC         to construct the PPO dataflow.         The light-weight advantage computation is done on the driver process.         """         from omegaconf import OmegaConf          from verl.utils.tracking import Tracking          logger = Tracking(             project_name=self.config.trainer.project_name,             experiment_name=self.config.trainer.experiment_name,             default_backend=self.config.trainer.logger,             config=OmegaConf.to_container(self.config, resolve=True),         )          self.global_steps = 0         self.gen_steps = 0          # load checkpoint before doing anything         self._load_checkpoint()          # perform validation before training         # currently, we only support validation using the reward_function.         
if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True):             val_metrics = self._validate()             assert val_metrics, f"{val_metrics=}"             pprint(f"Initial validation metrics: {val_metrics}")             logger.log(data=val_metrics, step=self.global_steps)             if self.config.trainer.get("val_only", False):                 return          if self.config.actor_rollout_ref.rollout.get("skip_rollout", False):             rollout_skip = RolloutSkip(self.config, self.actor_rollout_wg)             rollout_skip.wrap_generate_sequences()          # add tqdm         progress_bar = tqdm(total=self.total_training_steps, initial=self.global_steps, desc="Training Progress")          # we start from step 1         self.global_steps += 1         self.gen_steps += 1         last_val_metrics = None          prev_step_profile = False         curr_step_profile = (             self.global_steps in self.config.global_profiler.steps             if self.config.global_profiler.steps is not None             else False         )         next_step_profile = False          timing_raw = defaultdict(float)         batch = None         #################################################################         # num_prompt_in_batch:记录filter后,std不等于0的q的个数,当模型更新后重新赋值为0         # num_gen_batches: 记录当前使用了多少个gen_batch,当模型更新后重新赋值为0         #################################################################         num_prompt_in_batch = 0         num_gen_batches = 0         #################################################################         # 正式开始训练,循环每个epoch后,循环每个gen_batch         #################################################################         for epoch in range(self.config.trainer.total_epochs):             for batch_dict in self.train_dataloader:                 metrics = {}                  with marked_timer("start_profile", timing_raw):                     self._start_profiling(                         not prev_step_profile 
and curr_step_profile                         if self.config.global_profiler.profile_continuous_steps                         else curr_step_profile                     )                  #################################################################                 # new_batch 是DataProto类型(具体见verl/verl/protocol.py),                 # new_batch.batch是TensorDict类型                 # new_batch中q的数量是可训练batch大小的3倍(增加采样的batch的q的个数)                 #################################################################                 new_batch: DataProto = DataProto.from_single_dict(batch_dict)                 num_gen_batches += 1                 # pop those keys for generation                 if "multi_modal_data" in new_batch.non_tensor_batch.keys():                     gen_batch = new_batch.pop(                         batch_keys=["input_ids", "attention_mask", "position_ids"],                         non_tensor_batch_keys=["raw_prompt_ids", "multi_modal_data"],                     )                 else:                     # 从new_batch中提取对应的key,构建gen_batch                     gen_batch = new_batch.pop(                         batch_keys=["input_ids", "attention_mask", "position_ids"],                         non_tensor_batch_keys=["raw_prompt_ids"],                     )                 # 这里为什么要repeate呢,因为每个prompt要采样n次,所以repeat n次。这里的interleave=True                 # gen_batch: (bsz, response_length),  								# gen_batch_output: (bsz*n, response_length)                 gen_batch_output = gen_batch.repeat(                     repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True                 )                  is_last_step = self.global_steps >= self.total_training_steps                  with marked_timer("step", timing_raw):                     # generate a batch                     with marked_timer("gen", timing_raw, "red"):                         gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch_output)                        
 timing_raw.update(gen_batch_output.meta_info["timing"])                         gen_batch_output.meta_info.pop("timing", None)                      # 这个advatange 可以先忽略。RMAX需要先计算 贪心采样的sample的logits作为后序adv计算的baseline                     if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX:                         with marked_timer("gen_max", timing_raw, "red"):                             gen_baseline_batch = deepcopy(gen_batch)                             # 这里是贪心采样的baseline,do_sample = False                             gen_baseline_batch.meta_info["do_sample"] = False                             gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch)                              new_batch = new_batch.union(gen_baseline_output)                             # compute reward model score on new_batch                             rm_scores = None                             if self.use_rm and "rm_scores" not in new_batch.batch.keys():                                 rm_scores = self.rm_wg.compute_rm_score(new_batch)                                 new_batch = new_batch.union(rm_scores)                             reward_baseline_tensor, _ = compute_reward(new_batch, self.reward_fn)                             reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1)                              keys_to_pop = set(gen_baseline_output.batch.keys())                             if rm_scores is not None:                                 keys_to_pop.update(rm_scores.batch.keys())                             new_batch.pop(batch_keys=list(keys_to_pop))                              new_batch.batch["reward_baselines"] = reward_baseline_tensor                              del rm_scores, gen_baseline_batch, gen_baseline_output                      #################################################################                     # new_batch的大小是gen_prompt_bsz                     # 对每一个prompt设置一个专属的标识 uid 										# 
之所以设置uid,是因为之后对sample计算reward时,需要对同一个q的n个sample的reward标准化                     #################################################################                     new_batch.non_tensor_batch["uid"] = np.array(                         [str(uuid.uuid4()) for _ in range(len(new_batch.batch))], dtype=object                     )                     # 对batch中的每个key进行repeat(这里应该主要是对uid进行repeat)                     # repeat to align with repeated responses in rollout                     new_batch = new_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)                     # 把采样完的放到new_batch中                     new_batch = new_batch.union(gen_batch_output)                      with marked_timer("reward", timing_raw, "yellow"):                         # compute scores. Support both model and function-based.                         # We first compute the scores using reward model. Then, we call reward_fn to combine                         # the results from reward model and rule-based results.                         if self.use_rm and "rm_scores" not in new_batch.batch.keys():                             # we first compute reward model score                             reward_tensor = self.rm_wg.compute_rm_score(new_batch)                             new_batch = new_batch.union(reward_tensor)                          # 计算new_batch各个采样的reward,根据设置好的self.reward_fn                         # we combine with rule-based rm                         reward_tensor, reward_extra_infos_dict = compute_reward(new_batch, self.reward_fn)                          new_batch.batch["token_level_scores"] = reward_tensor                          if reward_extra_infos_dict:                             new_batch.non_tensor_batch.update(                                 {k: np.array(v) for k, v in reward_extra_infos_dict.items()}                             )                          # compute rewards. 
apply_kl_penalty if available                         if self.config.algorithm.use_kl_in_reward:                             new_batch, kl_metrics = apply_kl_penalty(                                 new_batch, kl_ctrl=self.kl_ctrl_in_reward, kl_penalty=self.config.algorithm.kl_penalty                             )                             metrics.update(                                 kl_metrics                             )  # TODO: This will be cleared if we use multiple genenration batches                         else:                             new_batch.batch["token_level_rewards"] = new_batch.batch["token_level_scores"] 										                     #################################################################                     # dapo的filter(dynamic sample)部分                     #################################################################                     if not self.config.algorithm.filter_groups.enable:                         batch = new_batch                     else:  # NOTE: When prompts after filtering is less than train batch size,                         # we skip to the next generation batch                         metric_name = self.config.algorithm.filter_groups.metric                         if metric_name == "seq_final_reward":                             # Turn to numpy for easier filtering                             new_batch.non_tensor_batch["seq_final_reward"] = (                                 new_batch.batch["token_level_rewards"].sum(dim=-1).numpy()                             )                         elif metric_name == "seq_reward":                             new_batch.non_tensor_batch["seq_reward"] = (                                 new_batch.batch["token_level_scores"].sum(dim=-1).numpy()                             )                          # {uid: [r1,r2,r3,...,rn], uid: [...], ...},记录每个轨迹所有采样的reward                         # Collect the sequence reward for each trajectory                         
prompt_uid2metric_vals = defaultdict(list)                         for uid, metric_val in zip(                             new_batch.non_tensor_batch["uid"], new_batch.non_tensor_batch[metric_name], strict=True                         ):                             prompt_uid2metric_vals[uid].append(metric_val)                          # 每个q的reward的std                         prompt_uid2metric_std = {}                         for prompt_uid, metric_vals in prompt_uid2metric_vals.items():                             prompt_uid2metric_std[prompt_uid] = np.std(metric_vals)                          # 保留reward std不是0的q的uid                         kept_prompt_uids = [                             uid                             for uid, std in prompt_uid2metric_std.items()                             if std > 0 or len(prompt_uid2metric_vals[uid]) == 1                         ]                         # 累积std不是0的q                         num_prompt_in_batch += len(kept_prompt_uids)                          # 记录留下来的q的sample的idx                         kept_traj_idxs = []                         for idx, traj_from_prompt_uid in enumerate(new_batch.non_tensor_batch["uid"]):                             if traj_from_prompt_uid in kept_prompt_uids:                                 kept_traj_idxs.append(idx)                          # 基于traj的id,检索对应的new_batch                         new_batch = new_batch[kept_traj_idxs]                         # batch是留下的traj数据的累积                         batch = new_batch if batch is None else DataProto.concat([batch, new_batch])                          # .sh文件配置的 可以训练的batch的最小大小(q的数量)                         prompt_bsz = self.config.data.train_batch_size                         # 如果现有的累积filter出来的q的数量小于 配置的最小数量,则continue继续使用下一个new_batch进行累积                         if num_prompt_in_batch < prompt_bsz:                             print(f"{num_prompt_in_batch=} < {prompt_bsz=}")                             max_num_gen_batches = 
self.config.algorithm.filter_groups.max_num_gen_batches                             # max_num_gen_batches是最多可以使用的gen_batch的个数                             # 如果其小于0的话,即没有限制;若num_gen_batches < max_num_gen_batches则继续continue                             if max_num_gen_batches <= 0 or num_gen_batches < max_num_gen_batches:                                 print(f"{num_gen_batches=}. Keep generating...")                                 self.gen_steps += 1                                 is_last_step = self.global_steps >= self.total_training_steps                                 continue                             else:                                 raise ValueError(                                     f"{num_gen_batches=} >= {max_num_gen_batches=}."                                     + " Generated too many. Please check if your data are too difficult."                                     + " You could also try set max_num_gen_batches=0 to enable endless trials."                                 )                       # 累积的符合的q个个数>=最小的可以训练的batch的大小                         else:                             # Align the batch                             traj_bsz = self.config.data.train_batch_size * self.config.actor_rollout_ref.rollout.n                             #################################################################                             # 对齐一下,多余的轨迹会被抛弃,不知道会不会导致采样的利用效率不高,                             # 会不会导致一些轨迹根本不会被训练到                             #################################################################                             batch = batch[:traj_bsz]                      #################################################################                     # actor模型更新                     #################################################################                     # === Updating ===                      batch.batch["response_mask"] = compute_response_mask(batch)                      # Balance the number of valid tokens across DP 
ranks.                     # NOTE: This usually changes the order of data in the `batch`,                     # which won't affect the advantage calculation (since it's based on uid),                     # but might affect the loss calculation (due to the change of mini-batching).                     # TODO: Decouple the DP balancing and mini-batching.                     if self.config.trainer.balance_batch:                         self._balance_batch(batch, metrics=metrics)                      # compute global_valid tokens                     batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist()                      #################################################################                     # 记录filter后的batch的每个traj的采样时的logtis(token-level)                     # 用于计算重要性采样的比值                     #################################################################                     # recompute old_log_probs                     with marked_timer("old_log_prob", timing_raw, "blue"):                         old_log_prob = self.actor_rollout_wg.compute_log_prob(batch)                         entropys = old_log_prob.batch["entropys"]                         response_masks = batch.batch["response_mask"]                         loss_agg_mode = self.config.actor_rollout_ref.actor.loss_agg_mode                         # 这里dapo的loss_agg_mode是“token_mean”                         entropy_agg = agg_loss(loss_mat=entropys, loss_mask=response_masks, loss_agg_mode=loss_agg_mode)                         old_log_prob_metrics = {"actor/entropy": entropy_agg.detach().item()}                         metrics.update(old_log_prob_metrics)                         old_log_prob.batch.pop("entropys")                         batch = batch.union(old_log_prob)                      if self.use_reference_policy:                         # compute reference log_prob                         with marked_timer("ref", timing_raw, "olive"):                
             ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)                             batch = batch.union(ref_log_prob)                      # compute values                     if self.use_critic:                         with marked_timer("values", timing_raw, "cyan"):                             values = self.critic_wg.compute_values(batch)                             batch = batch.union(values)                      # 计算token_level的重要性采样                     # Compute rollout IS weights and mismatch metrics (inherited from RayPPOTrainer)                     batch, is_metrics = self.compute_rollout_importance_weights_and_add_to_batch(batch)                     # IS and mismatch metrics already have mismatch/ prefix                     metrics.update(is_metrics)                      #################################################################                     # 计算advantage                     #################################################################                     with marked_timer("adv", timing_raw, "brown"):                         # compute advantages, executed on the driver process                         norm_adv_by_std_in_grpo = self.config.algorithm.get("norm_adv_by_std_in_grpo", True)                         batch = compute_advantage(                             batch,                             adv_estimator=self.config.algorithm.adv_estimator,                             gamma=self.config.algorithm.gamma,                             lam=self.config.algorithm.lam,                             num_repeat=self.config.actor_rollout_ref.rollout.n,                             norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo,                         )                      # update critic                     if self.use_critic:                         with marked_timer("update_critic", timing_raw, "pink"):                             critic_output = self.critic_wg.update_critic(batch)                         
critic_output_metrics = reduce_metrics(critic_output.meta_info["metrics"])                         metrics.update(critic_output_metrics)                      # implement critic warmup                     if self.config.trainer.critic_warmup <= self.global_steps:                         #################################################################                         # 更新actor model(batch的大小是train_prompt_size)                         # 每个mini_bsz 更新一次模型(参数-累积梯度)                         # 每个micro_bsz 累积一次梯度                         #################################################################                         # update actor                         with marked_timer("update_actor", timing_raw, "red"):                             actor_output = self.actor_rollout_wg.update_actor(batch)                         actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"])                         metrics.update(actor_output_metrics)                      # Log rollout generations if enabled                     rollout_data_dir = self.config.trainer.get("rollout_data_dir", None)                     if rollout_data_dir:                         self._log_rollout_data(batch, reward_extra_infos_dict, timing_raw, rollout_data_dir)                  # validate                 if (                     self.val_reward_fn is not None                     and self.config.trainer.test_freq > 0                     and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0)                 ):                     with marked_timer("testing", timing_raw, "green"):                         val_metrics: dict = self._validate()                         if is_last_step:                             last_val_metrics = val_metrics                     metrics.update(val_metrics)                  if self.config.trainer.save_freq > 0 and (                     is_last_step or self.global_steps % self.config.trainer.save_freq == 0                 ):           
          with marked_timer("save_checkpoint", timing_raw, "green"):                         self._save_checkpoint()                  with marked_timer("stop_profile", timing_raw):                     next_step_profile = (                         self.global_steps + 1 in self.config.global_profiler.steps                         if self.config.global_profiler.steps is not None                         else False                     )                     self._stop_profiling(                         curr_step_profile and not next_step_profile                         if self.config.global_profiler.profile_continuous_steps                         else curr_step_profile                     )                     prev_step_profile = curr_step_profile                     curr_step_profile = next_step_profile                  # collect metrics                 metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic))                 metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw))                 # TODO: implement actual tflpo and theoretical tflpo                 n_gpus = self.resource_pool_manager.get_n_gpus()                 metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus))                 timing_raw = defaultdict(float)  # clear timing                  metrics["train/num_gen_batches"] = num_gen_batches                 batch = None                 num_prompt_in_batch = 0                 num_gen_batches = 0                  # TODO: make a canonical logger that supports various backend                 logger.log(data=metrics, step=self.global_steps)                  if is_last_step:                     pprint(f"Final validation metrics: {last_val_metrics}")                     progress_bar.close()                     return                  progress_bar.update(1)                 self.global_steps += 1                 self.gen_steps += 1         # check if last step checkpint exists      
   checkpoint_dir = os.path.join(self.config.trainer.default_local_dir, f"global_step_{self.global_steps}")         if not os.path.exists(checkpoint_dir):             # save last step checkpoint             timing_raw = defaultdict(float)             with marked_timer("save_checkpoint", timing_raw, "green"):                 self._save_checkpoint()             metrics = {f"timing/{k}": v for k, v in timing_raw.items()}             logger.log(data=metrics, step=self.global_steps)  

接下来再看一下 DAPO 的 reward manager 实现:它与 PPO 的主要区别在于使用了 overlong_buffer,根据回复长度对接近或达到最大长度的回复施加长度惩罚(overlong reward)。

verl/verl/workers/reward_manager/dapo.py

#################################################################
# Registered under the name "dapo", so the class can be looked up via
# reward_manager_cls = get_reward_manager_cls(reward_manager_name).
#################################################################
@register("dapo")
class DAPORewardManager(AbstractRewardManager):
    """Rule-based reward manager used by the DAPO recipe.

    Each response is scored with ``compute_score``. When an overlong buffer is
    configured and enabled, a soft length penalty is added for responses whose
    length enters the last ``overlong_buffer_cfg.len`` tokens before
    ``max_resp_len``.
    """

    def __init__(
        self,
        tokenizer,
        num_examine,
        compute_score=None,
        reward_fn_key="data_source",
        max_resp_len=None,
        overlong_buffer_cfg=None,
    ) -> None:
        """
        Args:
            tokenizer: tokenizer used to decode prompt/response token ids.
            num_examine: number of decoded responses per data source to print to the console.
            compute_score: scoring callable; falls back to ``default_compute_score``.
            reward_fn_key: non-tensor batch key that holds the data source name.
            max_resp_len: maximum response length; required when ``overlong_buffer_cfg`` is given.
            overlong_buffer_cfg: config object read for ``enable``, ``len``, ``penalty_factor``
                and ``log``; None disables the length penalty.
        """
        self.tokenizer = tokenizer
        self.num_examine = num_examine  # the number of batches of decoded responses to print to the console
        self.compute_score = compute_score or default_compute_score
        self.reward_fn_key = reward_fn_key
        self.overlong_buffer_cfg = overlong_buffer_cfg
        self.max_resp_len = max_resp_len

        if self.overlong_buffer_cfg is not None:
            assert self.max_resp_len is not None, (
                f"max_resp_len must be provided if {overlong_buffer_cfg=}, but got None"
            )
            assert self.max_resp_len >= self.overlong_buffer_cfg.len, (
                "max_resp_len must be larger than overlong_buffer.len"
            )

    def _overlong_penalty(self, response_length: int) -> float:
        """Return the non-positive soft length penalty for a response of ``response_length`` tokens.

        The penalty is 0 up to ``max_resp_len - len`` and then decreases linearly,
        reaching ``-penalty_factor`` at ``max_resp_len``. A non-positive buffer
        length disables the penalty instead of dividing by zero.
        """
        overlong_buffer_len = self.overlong_buffer_cfg.len
        if overlong_buffer_len <= 0:
            return 0.0
        expected_len = self.max_resp_len - overlong_buffer_len
        exceed_len = response_length - expected_len
        overlong_penalty_factor = self.overlong_buffer_cfg.penalty_factor
        return min(-exceed_len / overlong_buffer_len * overlong_penalty_factor, 0.0)

    #################################################################
    # Main entry point of the DAPO reward manager
    #################################################################
    def __call__(self, data: DataProto, return_dict: bool = False):
        """Compute token-level rewards for a batch.

        Each sample's scalar reward is written at the index of its last valid
        response token; every other position of ``reward_tensor`` stays 0.

        Args:
            data: batch providing ``prompts``, ``responses`` and ``attention_mask`` tensors and
                ``reward_model`` / ``self.reward_fn_key`` (optionally ``extra_info`` and
                ``reward_scores``) non-tensor entries.
            return_dict: when True, also return the per-sample extra info.

        Returns:
            ``reward_tensor`` (float32, same shape as ``responses``), or a dict with
            ``reward_tensor`` and ``reward_extra_info`` when ``return_dict`` is True.
        """
        # If reward-model scores are already present, return them directly;
        # otherwise compute rule-based rewards with self.compute_score.
        if "rm_scores" in data.batch.keys():
            if return_dict:
                reward_extra_keys = data.meta_info.get("reward_extra_keys", [])
                reward_extra_info = {key: data.non_tensor_batch[key] for key in reward_extra_keys}
                return {"reward_tensor": data.batch["rm_scores"], "reward_extra_info": reward_extra_info}
            return data.batch["rm_scores"]

        reward_tensor = torch.zeros_like(data.batch["responses"], dtype=torch.float32)
        reward_extra_info = defaultdict(list)
        overlong_enabled = self.overlong_buffer_cfg is not None and self.overlong_buffer_cfg.enable

        already_print_data_sources = {}

        for i in range(len(data)):
            data_item = data[i]  # DataProtoItem

            prompt_ids = data_item.batch["prompts"]
            prompt_length = prompt_ids.shape[-1]
            attention_mask = data_item.batch["attention_mask"]

            # Prompts are left-padded, responses are right-padded.
            # Slicing from (prompt_length - valid_len) stays correct when valid_len == 0,
            # unlike prompt_ids[-valid_len:], which would return the whole padded prompt.
            valid_prompt_length = int(attention_mask[:prompt_length].sum())
            valid_prompt_ids = prompt_ids[prompt_length - valid_prompt_length :]

            response_ids = data_item.batch["responses"]
            # Plain int so penalties and logged values are Python numbers, not tensors.
            valid_response_length = int(attention_mask[prompt_length:].sum())
            valid_response_ids = response_ids[:valid_response_length]

            # decode
            prompt_str = self.tokenizer.decode(valid_prompt_ids, skip_special_tokens=True)
            response_str = self.tokenizer.decode(valid_response_ids, skip_special_tokens=True)
            eos_token = self.tokenizer.eos_token
            # eos_token may be None for some tokenizers; str.endswith(None) would raise.
            if eos_token and response_str.endswith(eos_token):
                response_str = response_str[: -len(eos_token)]

            ground_truth = data_item.non_tensor_batch["reward_model"]["ground_truth"]

            data_source = data_item.non_tensor_batch[self.reward_fn_key]

            # Shallow copy: the stored dict belongs to the batch (and may be shared by
            # repeated samples), so the key added below must not mutate it. A stored
            # None falls back to an empty dict.
            extra_info = dict(data_item.non_tensor_batch.get("extra_info") or {})

            rollout_reward_scores = data_item.non_tensor_batch.get("reward_scores", {})

            extra_info["rollout_reward_scores"] = rollout_reward_scores

            result = self.compute_score(
                data_source=data_source,
                solution_str=response_str,
                ground_truth=ground_truth,
                extra_info=extra_info,
            )

            score: float
            if isinstance(result, dict):
                score = result["score"]
                # Store the information including original reward
                for key, value in result.items():
                    reward_extra_info[key].append(value)
            else:
                score = result
                reward_extra_info["acc"].append(score)

            reward = score

            # Overlong reward shaping: soft length penalty inside the overlong buffer.
            if overlong_enabled:
                overlong_reward = self._overlong_penalty(valid_response_length)
                reward += overlong_reward
                if self.overlong_buffer_cfg.log:
                    reward_extra_info["overlong_reward"].append(overlong_reward)
                    reward_extra_info["overlong"].append(overlong_reward < 0)

            # NOTE(review): an empty response (length 0) writes to index -1 here, as in the original.
            reward_tensor[i, valid_response_length - 1] = reward

            if data_source not in already_print_data_sources:
                already_print_data_sources[data_source] = 0

            if already_print_data_sources[data_source] < self.num_examine:
                already_print_data_sources[data_source] += 1
                print("[prompt]", prompt_str)
                print("[response]", response_str)
                print("[ground_truth]", ground_truth)
                if isinstance(result, dict):
                    for key, value in result.items():
                        print(f"[{key}]", value)
                else:
                    print("[score]", score)

        if return_dict:
            return {
                "reward_tensor": reward_tensor,
                "reward_extra_info": reward_extra_info,
            }
        return reward_tensor

DAPO 与 PPO 的具体区别可进一步参考 DAPO 的 README(recipe/dapo/README.md)。

发表评论

评论已关闭。

相关文章