GitHub: https://github.com/PeterGriffinJin/Search-R1
Motivation
Use a search engine to empower reasoning LLMs.
Method
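Building on PPO, trajectories are generated by interleaving the policy's output with calls to a given search engine (R).
Tokens returned by (R) must be masked out of the loss.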
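To make the masking concrete, the objective can be sketched in standard PPO-clip form. The notation is mine: R is the search engine, and I(y_t) marks tokens generated by the policy. Treat this as a reading aid rather than the paper's exact equation.

```latex
\max_{\pi_\theta}\;
\mathbb{E}_{x\sim\mathcal{D},\;y\sim\pi_\theta(\cdot\mid x;\mathcal{R})}\big[r_\phi(x,y)\big]
-\beta\,\mathbb{D}_{\mathrm{KL}}\big[\pi_\theta(y\mid x;\mathcal{R})\,\big\|\,\pi_{\mathrm{ref}}(y\mid x;\mathcal{R})\big]

\mathcal{J}_{\mathrm{PPO}}(\theta)=
\mathbb{E}\Bigg[\frac{1}{\sum_{t}I(y_t)}\sum_{t:\,I(y_t)=1}
\min\Big(\rho_t(\theta)\,A_t,\;
\operatorname{clip}\big(\rho_t(\theta),1-\epsilon,1+\epsilon\big)\,A_t\Big)\Bigg],
\qquad
\rho_t(\theta)=\frac{\pi_\theta(y_t\mid x,y_{<t};\mathcal{R})}{\pi_{\mathrm{old}}(y_t\mid x,y_{<t};\mathcal{R})}

I(y_t)=\begin{cases}1, & y_t \text{ generated by the policy}\\ 0, & y_t \text{ copied from } \mathcal{R}\end{cases}
```

Tokens with I(y_t)=0 drop out of both the sum and the normalizer, so retrieved text receives no gradient.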
Experiments
PPO is the default algorithm, and overall Search-R1's RL training is effective. The training set comes from NQ and HotpotQA.
-
PPO vs GRPO
The paper finds PPO more stable than GRPO and better-performing, while GRPO converges faster.


-
Instruct model vs base model
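The paper finds that the instruct model starts with a higher reward than the base model. Later in training, however, the two reach comparable rewards, and the base model ends up performing better than the instruct model.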
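(My take: the base model may beat the instruct model here because alignment (RL-based) reduces the instruct model's output diversity, which weakens its exploration on the search task. Yet WebDancer and similar works all use instruct models. I think this is because those works do not start with search RL directly. They first run RFT-style SFT to adapt the instruct model to the RL format and to inject search-domain skills (planning, tool calling, summarization, etc.). If the post-training RFT (often with little data) were applied to a base model instead, the base model would tend to fail at following instructions. That is why later WebAgent works that use SFT + RL generally start from an instruct model.)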


-
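Response length and valid search study
- Early stage: response length drops sharply while reward rises slightly. The model understands the search task better and its output becomes more concise.
- Later stage: response length climbs again and reward keeps rising. This turns out to be driven by an increase in the number of search calls.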

-
Ablation of retrieved-token masking
Masking is necessary: the model's training target is not to predict the retrieved tokens but to learn tool calling, planning and summarization.
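Here is a minimal sketch of what the mask does to the loss, written in plain Python so it runs without torch. The helper names build_loss_mask and masked_token_mean are hypothetical. In the Search-R1 code the mask comes from info_mask when state_masking is enabled.

```python
from typing import List, Sequence, Tuple


def build_loss_mask(segments: Sequence[Tuple[str, int]]) -> List[int]:
    """Expand (kind, n_tokens) segments into a per-token 0/1 mask.

    Hypothetical format: kind is "policy" for model-generated tokens and
    "retrieved" for tokens copied from the search engine.
    """
    mask: List[int] = []
    for kind, n_tokens in segments:
        mask.extend([1 if kind == "policy" else 0] * n_tokens)
    return mask


def masked_token_mean(per_token_loss: Sequence[float], mask: Sequence[int]) -> float:
    """Average the loss over unmasked (policy) tokens only; retrieved tokens are ignored."""
    if len(per_token_loss) != len(mask):
        raise ValueError("loss and mask must have the same length")
    n_kept = sum(mask)
    if n_kept == 0:
        return 0.0
    return sum(loss * m for loss, m in zip(per_token_loss, mask)) / n_kept
```

Changing the loss values at retrieved positions leaves the result unchanged. This is why those tokens receive no gradient.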


-
Number of Retrieved Passages Study in SEARCH-R1 Training
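More retrieved docs is not always better: the actor model becomes more prone to hallucinating or missing details when summarizing. Fewer is not better either, since even the best cook cannot make a meal without rice.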

-
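Group size of GRPO
A larger GRPO group size gives better results and faster convergence but less stability. (This feels like a problem with the paper's experimental design; I have never run into such a sharp reward drop.)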
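Group size matters because GRPO estimates each prompt's baseline from that prompt's own group of rollouts. Below is a minimal sketch of the group-relative advantage. It assumes the common (r - mean) / (std + eps) normalization, with groups defined by a uid shared by all rollouts of one question (the training loop later in this post copies uid from the dataset index). The function name is hypothetical, and using the sample standard deviation is one choice among several.

```python
import statistics
from collections import defaultdict
from typing import Dict, List, Sequence, Tuple


def grpo_advantages(uids: Sequence[str], rewards: Sequence[float], eps: float = 1e-6) -> List[float]:
    """Group-relative advantage (r - group mean) / (group std + eps).

    All rollouts sharing a uid (i.e. the same question) form one group.
    Uses the sample standard deviation; a single-rollout group gets advantage 0.
    """
    groups: Dict[str, List[float]] = defaultdict(list)
    for uid, reward in zip(uids, rewards):
        groups[uid].append(reward)

    # Per-group baseline (mean) and scale (std)
    stats: Dict[str, Tuple[float, float]] = {}
    for uid, group in groups.items():
        mean = statistics.mean(group)
        std = statistics.stdev(group) if len(group) > 1 else 0.0
        stats[uid] = (mean, std)

    return [(reward - stats[uid][0]) / (stats[uid][1] + eps) for uid, reward in zip(uids, rewards)]
```

With a small group, the mean and std are noisy estimates. A group in which every rollout gets the same reward yields all-zero advantages, so that prompt contributes no learning signal.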

Conclusion
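The paper proposes an RL method for agents, but it builds no SFT trajectory data. As a result, the model cannot learn planning, single-tool calling, or how multiple tools relate to each other.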
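Code implementation
Implementing agent RL is hard in two places. For each, I walk through and compare the code of naive RL and Search-R1:
- the trajectory generation loop
- the trajectory reward manager
1. Generating trajectory data in a loop
Unlike naive RL, Search-R1 must extract the action and tool call from each step's output and then invoke the retriever.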
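Here is a minimal sketch of the per-step parsing. It assumes the <search>...</search> / <answer>...</answer> tag format that the generation code below works with. ACTION_RE, truncate_at_action and parse_action are hypothetical names. Search-R1's own _postprocess_responses and execute_predictions differ in details, such as how invalid actions are reported.

```python
import re
from typing import Optional, Tuple

# First <search>...</search> or <answer>...</answer> block; \1 forces matching open/close tags
ACTION_RE = re.compile(r"<(search|answer)>(.*?)</\1>", re.DOTALL)
CLOSING_TAGS = ("</search>", "</answer>")


def truncate_at_action(text: str) -> str:
    """Drop everything after the earliest closing </search> or </answer> tag."""
    ends = [text.find(tag) + len(tag) for tag in CLOSING_TAGS if tag in text]
    return text[:min(ends)] if ends else text


def parse_action(text: str) -> Tuple[Optional[str], str]:
    """Return ("search", query), ("answer", answer) or (None, "") for an invalid step."""
    match = ACTION_RE.search(text)
    if match is None:
        return None, ""
    return match.group(1), match.group(2).strip()
```

Truncating first means that anything the model writes after its action (for example, a hallucinated observation) never enters the trajectory.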
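First, let's look at the naive implementation behind self.actor_rollout_wg.generate_sequences(gen_batch), which verl calls from verl.trainer.ppo.ray_trainer.py.
It lives in verl/workers/rollout/naive/naive_rollout.py. Rollout only samples and does not need to keep the computation graph, hence @torch.no_grad.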
```python
import torch
import torch.nn.functional as F
from tensordict import TensorDict
from torch import nn

from verl import DataProto
from verl.utils.torch_functional import logprobs_from_logits
from verl.workers.rollout.base import BaseRollout


class NaiveRollout(BaseRollout):

    def __init__(self, module: nn.Module, config):
        """A naive rollout. It requires the module to be compatible with huggingface APIs. That is:
        The module should define __call__ to receive input_ids, attention_mask and position_ids.
        It outputs a structure that contains logits field.

        Args:
            module: module here follows huggingface APIs
            config: DictConfig
        """
        super().__init__()
        self.config = config
        self.module = module

    #########################################################################
    # Rollout does not keep the computation graph
    #########################################################################
    @torch.no_grad()
    def generate_sequences(self, prompts: DataProto) -> DataProto:
        """Generate sequences"""
        #########################################################################
        # Note: with GRPO, batch['input_ids'] here has shape (batch_size * rollout.n, prompt_length),
        # because ray_trainer.py repeats the batch beforehand
        #########################################################################
        idx = prompts.batch['input_ids']  # (bs, prompt_length)
        attention_mask = prompts.batch['attention_mask']  # left-padded attention_mask
        position_ids = prompts.batch['position_ids']

        # used to construct attention_mask
        eos_token_id = prompts.meta_info['eos_token_id']

        batch_size = idx.size(0)
        prompt_length = idx.size(1)

        self.module.eval()

        # prev_attention_mask records whether each sequence is still being rolled out,
        # i.e. whether no eos_token has appeared before the token generated in the current iteration
        prev_attention_mask = torch.ones(size=(batch_size, 1), dtype=attention_mask.dtype, device=attention_mask.device)

        logits_lst = []
        #########################################################################
        # Idea: each iteration generates the next_token_id at the same position (position_id) for ALL
        # sequences, looping response_length times regardless of whether eos has been hit.
        # This generates all sequences in parallel with matrix ops instead of one sequence at a time,
        # which keeps rollout efficient.
        #########################################################################
        for _ in range(self.config.response_length):
            # if the sequence context is growing too long we must crop it at block_size
            # idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
            idx_cond = idx
            # forward the model to get the logits for the index in the sequence
            # we use huggingface APIs here
            output = self.module(input_ids=idx_cond, attention_mask=attention_mask, position_ids=position_ids)
            # logits: (bs, current_seq_len, vocab_size)
            logits = output.logits
            #########################################################################
            # Sampling operations:
            # temperature: divide the logits of every vocab entry by temp
            # top-k: set logits outside the top k to -inf, so these low-probability tokens
            #        get zero probability after softmax
            # do_sample: sample from the distribution, or take the argmax (greedy)
            #########################################################################
            # pluck the logits at the final step and scale by desired temperature
            logits = logits[:, -1, :] / self.config.temperature  # (bs, vocab_size)

            # optionally crop the logits to only the top k options
            if self.config.top_k is not None:
                v, _ = torch.topk(logits, min(self.config.top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float('Inf')
            # apply softmax to convert logits to (normalized) probabilities
            probs = F.softmax(logits, dim=-1)
            # sample from the distribution
            if self.config.do_sample:
                idx_next = torch.multinomial(probs, num_samples=1)
            else:
                idx_next = torch.argmax(probs, dim=-1, keepdim=True)

            #########################################################################
            # Concatenate attention_mask, position_ids and idx
            #########################################################################
            # Append the current token's mask to attention_mask.
            # Whether the current token is masked depends only on whether an eos_token appeared before it.
            attention_mask = torch.cat((attention_mask, prev_attention_mask), dim=-1)

            # If the current token is eos_token, or an eos_token appeared earlier, every later token is masked out
            prev_attention_mask = torch.logical_and(idx_next != eos_token_id, prev_attention_mask.bool())
            # .to() is not in-place: assign the result so the mask keeps attention_mask's dtype
            prev_attention_mask = prev_attention_mask.to(attention_mask.dtype)

            position_ids = torch.cat((position_ids, position_ids[:, -1:] + 1), dim=-1)

            # append sampled index to the running sequence and continue
            idx = torch.cat((idx, idx_next), dim=1)

            logits_lst.append(logits)

        # Stack the response_length tensors of shape (bs, vocab_size) along dim 1
        logits = torch.stack(logits_lst, dim=1)  # (bs, response_length, vocab_size)

        prompts = idx[:, :prompt_length]  # (bs, prompt_length)
        response = idx[:, prompt_length:]  # (bs, response_length)
        # Log-probability of each sampled token (log-softmax, then gather at the response ids)
        log_probs = logprobs_from_logits(logits=logits, labels=response)
        batch = TensorDict(
            {
                'input_ids': prompts,
                'responses': response,
                'sequences': idx,
                'old_log_probs': log_probs,
                'attention_mask': attention_mask,
                'position_ids': position_ids,
            },
            batch_size=batch_size)

        self.module.train()

        return DataProto(batch=batch)
```
The responses in the batch are effectively right-padded. For every sequence, all attention_mask positions after its first eos_idx are 0. This comes from the following code:
```python
# Append the current token's mask to attention_mask.
# Whether the current token is masked depends only on whether an eos_token appeared before it.
attention_mask = torch.cat((attention_mask, prev_attention_mask), dim=-1)

# If the current token is eos_token, or an eos_token appeared earlier, every later token is masked out
prev_attention_mask = torch.logical_and(idx_next != eos_token_id, prev_attention_mask.bool())
```
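The same recurrence can be checked on a single sequence without torch. The hypothetical response_mask helper below reproduces the rule in the snippet above: the EOS token itself keeps mask 1, and every token after it gets 0.

```python
from typing import List, Sequence


def response_mask(response_ids: Sequence[int], eos_token_id: int) -> List[int]:
    """1 for every token up to and including the first EOS, 0 afterwards."""
    mask: List[int] = []
    still_generating = True  # plays the role of prev_attention_mask
    for token in response_ids:
        # The current token's mask is the flag computed before this token was appended
        mask.append(int(still_generating))
        # Once EOS is emitted, every later token is masked out
        still_generating = still_generating and token != eos_token_id
    return mask
```

For example, response_mask([5, 7, 2, 9, 9], eos_token_id=2) gives [1, 1, 1, 0, 0]: a run of ones followed by zeros, i.e. right padding.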
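Having seen how naive generation produces the sequences for one batch, let's move on to how agent trajectories are generated.
A trajectory can be thought of simply as a loop over naive sequence generation, with one difference. The sequence generated at each step must be decoded to parse the tool call, and the tool's result is appended to the sequence as part of the prompt for the next step.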
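The loop just described can be sketched end to end with a scripted policy and a dictionary retriever. Everything here is hypothetical and deliberately simplified. It works on strings rather than token ids, treats an invalid action as the end of the trajectory, and skips the final no-search rollout.

```python
import re
from typing import Callable, List, Tuple

SEARCH_RE = re.compile(r"<search>(.*?)</search>", re.DOTALL)
ANSWER_RE = re.compile(r"<answer>(.*?)</answer>", re.DOTALL)


def run_agent_loop(
    prompt: str,
    policy: Callable[[str], str],
    retrieve: Callable[[str], str],
    max_turns: int = 4,
) -> Tuple[str, List[Tuple[str, str]]]:
    """Generate one trajectory and return (full_text, segments).

    segments is a list of (kind, text) with kind in {"policy", "retrieved"},
    so a loss mask can later zero out the retrieved parts.
    """
    context = prompt
    segments: List[Tuple[str, str]] = []
    for _ in range(max_turns):
        response = policy(context)              # one "naive" generation step
        segments.append(("policy", response))
        context += response
        if ANSWER_RE.search(response):          # final answer: trajectory is done
            break
        query = SEARCH_RE.search(response)
        if query is None:                       # invalid action: stop (a simplification)
            break
        obs = f"<information>{retrieve(query.group(1).strip())}</information>"
        segments.append(("retrieved", obs))
        context += obs                          # observation becomes part of the next prompt
    return context, segments
```

Consider a policy scripted to emit <search>capital of France</search> and then <answer>Paris</answer>. The returned segment kinds are policy, retrieved, policy, and the retrieved segment is the one a loss mask would zero out.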
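Search-R1's training flow lives in verl.trainer.ppo.ray_trainer.py. Its biggest difference from vanilla verl is the new LLMGenerationManager.run_llm_loop() method for generating agent trajectories, so we read that main module first: search_r1.llm_agent.generation.py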
```python
import torch
from dataclasses import dataclass
from typing import Dict, Tuple

from verl import DataProto


@dataclass
class GenerationConfig:
    max_turns: int
    # Maximum length of the initial prompt
    max_start_length: int
    # Maximum accumulated prompt length: start + (response + observation) * steps
    max_prompt_length: int
    # Maximum length of a single generated response
    max_response_length: int
    # Maximum length of the tool (search) output
    max_obs_length: int
    num_gpus: int
    # Whether to disable thinking
    no_think_rl: bool = False
    # URL of the search engine
    search_url: str = None
    # Number of docs to retrieve
    topk: int = 3


class LLMGenerationManager:
    ...

    #################################################################
    # Generate agent trajectories by looping up to config.max_turns times;
    # each trajectory is the concatenation of at most max_turns generated segments (plus observations)
    #################################################################
    def run_llm_loop(self, gen_batch, initial_input_ids: torch.Tensor) -> Tuple[Dict, Dict]:
        """Run main LLM generation loop."""
        #################################################################
        # Initialize per-trajectory state maintained across turns:
        # prompt, response, mask, status, action stats, search stats
        #################################################################
        # Left-padded
        original_left_side = {'input_ids': initial_input_ids[:, -self.config.max_start_length:]}
        # Right-padded
        original_right_side = {'responses': initial_input_ids[:, []], 'responses_with_info_mask': initial_input_ids[:, []]}

        # Whether each trajectory is still active in the current turn (unfinished and no error): (bsz * rollout.n)
        # If active_mask is 0, the example has finished (or failed) and no further turns are generated for it
        active_mask = torch.ones(gen_batch.batch['input_ids'].shape[0], dtype=torch.bool)
        # Number of active turns of each trajectory (its total turn count)
        turns_stats = torch.ones(gen_batch.batch['input_ids'].shape[0], dtype=torch.int)
        # Number of valid actions per trajectory (can differ from turns_stats, because the action
        # in some turn may be invalid, i.e. neither answer nor search)
        valid_action_stats = torch.zeros(gen_batch.batch['input_ids'].shape[0], dtype=torch.int)
        # Number of search actions per trajectory (usually turns_stats - number of answers, typically 1)
        valid_search_stats = torch.zeros(gen_batch.batch['input_ids'].shape[0], dtype=torch.int)
        # Number of active trajectories in each turn
        active_num_list = [active_mask.sum().item()]
        rollings = gen_batch

        #################################################################
        # Turn loop: each turn generates a response, extracts the tool call, calls the tool,
        # receives the observation, and appends everything to the prompt
        #################################################################
        # Main generation loop
        for step in range(self.config.max_turns):
            if not active_mask.sum():
                break
            rollings.batch = self.tensor_fn.cut_to_effective_len(
                rollings.batch,
                keys=['input_ids', 'attention_mask', 'position_ids']
            )

            # gen_output = self.actor_rollout_wg.generate_sequences(rollings)
            # Keep only the trajectories that are still active (by active_mask)
            rollings_active = DataProto.from_dict({
                k: v[active_mask] for k, v in rollings.batch.items()
            })
            # Assuming num_gpus == 1 (no data-parallel padding), this is just
            # gen_output = self.actor_rollout_wg.generate_sequences(rollings)
            gen_output = self._generate_with_gpu_padding(rollings_active)

            meta_info = gen_output.meta_info
            # Post-process responses (bsz * rollout.n, response_length): decode token ids to strings,
            # keep each response only up to its first </search> (search action) or </answer> (final answer),
            # then re-encode to right-padded ids
            responses_ids, responses_str = self._postprocess_responses(gen_output.batch['responses'])
            # Fill inactive examples with pad ids (and "" as the string) according to active_mask.
            # There can be fewer active examples than the batch size, so inactive ones are padded back in
            responses_ids, responses_str = self.tensor_fn._example_level_pad(responses_ids, responses_str, active_mask)

            # Call the search engine; returns per-example lists of
            # [retrieved docs, done flag, whether the action is valid, whether it is a search action]
            # Execute in environment and process observations
            next_obs, dones, valid_action, is_search = self.execute_predictions(
                responses_str, self.tokenizer.pad_token, active_mask
            )

            # A finished trajectory is masked out, hence 0
            curr_active_mask = torch.tensor([not done for done in dones], dtype=torch.bool)
            # A trajectory that was inactive in the previous turn stays inactive in this one
            active_mask = active_mask * curr_active_mask
            active_num_list.append(active_mask.sum().item())
            turns_stats[curr_active_mask] += 1
            valid_action_stats += torch.tensor(valid_action, dtype=torch.int)
            valid_search_stats += torch.tensor(is_search, dtype=torch.int)

            # Process the observation: tokenize (right-padded), then truncate to max_obs_length
            next_obs_ids = self._process_next_obs(next_obs)

            # Concatenate rollings + responses_ids + next_obs_ids for each trajectory.
            # Note: rollings is left-padded while responses and observations are right-padded,
            # so after concatenation the padding ids in the middle are moved to the left
            # (other tokens keep their relative order) to keep rollings left-padded
            # Update states
            rollings = self._update_rolling_state(
                rollings,
                responses_ids,
                next_obs_ids
            )
            # Likewise concatenate original_right_side + responses + observations,
            # but keep the result right-padded
            original_right_side = self._update_right_side(
                original_right_side,
                responses_ids,
                next_obs_ids
            )

        # Some examples may still be active after max_turns turns (no answer yet),
        # so they get one final rollout with search disabled
        # final LLM rollout
        if active_mask.sum():
            rollings.batch = self.tensor_fn.cut_to_effective_len(
                rollings.batch,
                keys=['input_ids', 'attention_mask', 'position_ids']
            )

            # gen_output = self.actor_rollout_wg.generate_sequences(rollings)
            rollings_active = DataProto.from_dict({
                k: v[active_mask] for k, v in rollings.batch.items()
            })
            gen_output = self._generate_with_gpu_padding(rollings_active)

            meta_info = gen_output.meta_info
            responses_ids, responses_str = self._postprocess_responses(gen_output.batch['responses'])
            responses_ids, responses_str = self.tensor_fn._example_level_pad(responses_ids, responses_str, active_mask)

            # # Execute in environment and process observations
            _, dones, valid_action, is_search = self.execute_predictions(
                responses_str, self.tokenizer.pad_token, active_mask, do_search=False
            )

            curr_active_mask = torch.tensor([not done for done in dones], dtype=torch.bool)
            active_mask = active_mask * curr_active_mask
            active_num_list.append(active_mask.sum().item())
            valid_action_stats += torch.tensor(valid_action, dtype=torch.int)
            valid_search_stats += torch.tensor(is_search, dtype=torch.int)

            original_right_side = self._update_right_side(
                original_right_side,
                responses_ids,
            )

        meta_info['turns_stats'] = turns_stats.tolist()
        meta_info['active_mask'] = active_mask.tolist()
        meta_info['valid_action_stats'] = valid_action_stats.tolist()
        meta_info['valid_search_stats'] = valid_search_stats.tolist()

        print("ACTIVE_TRAJ_NUM:", active_num_list)

        return self._compose_final_output(original_left_side, original_right_side, meta_info)

    # Concatenate original_left_side with the accumulated model outputs and tool results
    def _compose_final_output(self, left_side: Dict,
                              right_side: Dict,
                              meta_info: Dict) -> Tuple[Dict, Dict]:
        """Compose final generation output."""
        final_output = right_side.copy()
        final_output['prompts'] = left_side['input_ids']

        # Combine input IDs
        final_output['input_ids'] = torch.cat([
            left_side['input_ids'],
            right_side['responses']
        ], dim=1)

        # Create attention mask and position ids
        final_output['attention_mask'] = torch.cat([
            self.tensor_fn.create_attention_mask(left_side['input_ids']),
            self.tensor_fn.create_attention_mask(final_output['responses'])
        ], dim=1)
        # info_mask additionally zeroes the retrieved-observation tokens,
        # which are pad ids in responses_with_info_mask
        final_output['info_mask'] = torch.cat([
            self.tensor_fn.create_attention_mask(left_side['input_ids']),
            self.tensor_fn.create_attention_mask(final_output['responses_with_info_mask'])
        ], dim=1)

        final_output['position_ids'] = self.tensor_fn.create_position_ids(
            final_output['attention_mask']
        )

        final_output = DataProto.from_dict(final_output)
        final_output.meta_info.update(meta_info)

        return final_output
```
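The generation code above moves the padding to the left after concatenating a left-padded prompt with right-padded response and observation ids. Below is a plain-Python sketch of that idea for one row. It also derives the attention mask and position ids, using the usual max(cumsum(mask) - 1, 0) convention for left-padded batches. It assumes pad_id never occurs as a real token. The helper names are hypothetical, and a stable sort on the pad indicator is one way to do this, not necessarily how Search-R1's tensor helper is written.

```python
from typing import List, Sequence, Tuple


def concat_and_left_pad(prompt: Sequence[int], response: Sequence[int],
                        obs: Sequence[int], pad_id: int) -> List[int]:
    """Concatenate one row (left-padded prompt + right-padded response + right-padded obs)
    and move every pad to the front, keeping real tokens in their original order."""
    row = list(prompt) + list(response) + list(obs)
    # Python's sort is stable: pads (key False) come first, all other tokens keep their order
    order = sorted(range(len(row)), key=lambda i: row[i] != pad_id)
    return [row[i] for i in order]


def attention_and_positions(row: Sequence[int], pad_id: int) -> Tuple[List[int], List[int]]:
    """attention_mask = non-pad indicator; position_ids = max(cumsum(mask) - 1, 0)."""
    mask = [int(token != pad_id) for token in row]
    positions: List[int] = []
    running = 0
    for m in mask:
        running += m
        positions.append(max(running - 1, 0))
    return mask, positions
```

For example, a prompt [0, 0, 11, 12], response [21, 22, 0, 0] and observation [31, 0] with pad_id 0 repack to [0, 0, 0, 0, 0, 11, 12, 21, 22, 31]. The real tokens then get positions 0 through 4.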
Now let's return to Search-R1's RL flow in ray_trainer.py:
```python
#########################################################################
# Search-R1 extends verl's trainer/ppo/ray_trainer.py source directly, adding a new
# generation manager for agent trajectories (its code is walked through above).
# First, the overall Search-R1 training flow:
#########################################################################
def fit(self):
    """
    The training loop of PPO.
    The driver process only need to call the compute functions of the worker group through RPC to construct the PPO dataflow.
    The light-weight advantage computation is done on the driver process.
    """

    logger = self.logger
    self.global_steps = 0
    # perform validation before training
    # currently, we only support validation using the reward_function.
    if self.val_reward_fn is not None and self.config.trainer.get('val_before_train', True):
        val_metrics = self._validate()
        pprint(f'Initial validation metrics: {val_metrics}')
        logger.log(data=val_metrics, step=self.global_steps)
        if self.config.trainer.get('val_only', False):
            return

    # we start from step 1
    self.global_steps += 1

    #########################################################################
    # Newly added: the generation module for agent trajectories
    # Agent config preparation
    gen_config = GenerationConfig(
        max_turns=self.config.max_turns,
        max_start_length=self.config.data.max_start_length,
        max_prompt_length=self.config.data.max_prompt_length,
        max_response_length=self.config.data.max_response_length,
        max_obs_length=self.config.data.max_obs_length,
        num_gpus=self.config.trainer.n_gpus_per_node * self.config.trainer.nnodes,
        no_think_rl=self.config.algorithm.no_think_rl,
        search_url=self.config.retriever.url,
        topk=self.config.retriever.topk,
    )

    generation_manager = LLMGenerationManager(
        tokenizer=self.tokenizer,
        actor_rollout_wg=self.actor_rollout_wg,
        config=gen_config,
    )
    #########################################################################

    #########################################################################
    # This loop is still verl's original code: iterate over training epochs
    # start training loop
    for epoch in range(self.config.trainer.total_epochs):
        for batch_dict in self.train_dataloader:
            print(f'epoch {epoch}, step {self.global_steps}')
            metrics = {}
            timing_raw = {}

            # Fetch one batch of training data (bsz, prompt_length) and repeat it (GRPO needs the repeat)
            # Note: prompts are left-padded
            batch: DataProto = DataProto.from_single_dict(batch_dict)
            batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n_agent, interleave=True)

            # pop those keys for generation
            gen_batch = batch.pop(batch_keys=['input_ids', 'attention_mask', 'position_ids'])

            ####################
            # original code here

            with _timer('step', timing_raw):
                if not self.config.do_search:
                    gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)

                    batch.non_tensor_batch['uid'] = np.array([str(uuid.uuid4()) for _ in range(len(batch.batch))], dtype=object)
                    # repeat to align with repeated responses in rollout
                    batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)
                    batch = batch.union(gen_batch_output)

                #########################################################################
                # From here on is the new Search-R1 training flow
                #########################################################################
                ####################
                # Below is all about agents - the "LLM + forloop"
                ####################
                # with _timer('step', timing_raw):
                else:
                    # Left-truncate first, keeping only the rightmost max_start_length prompt ids
                    first_input_ids = gen_batch.batch['input_ids'][:, -gen_config.max_start_length:].clone().long()

                    with _timer('gen', timing_raw):
                        generation_manager.timing_raw = timing_raw
                        # Generate data here: (bsz * rollout.n, prompt_length + response_length)
                        final_gen_batch_output = generation_manager.run_llm_loop(
                            gen_batch=gen_batch,
                            initial_input_ids=first_input_ids,
                        )

                    # final_gen_batch_output.batch.apply(lambda x: x.long(), inplace=True)
                    for key in final_gen_batch_output.batch.keys():
                        final_gen_batch_output.batch[key] = final_gen_batch_output.batch[key].long()

                    with torch.no_grad():
                        output = self.actor_rollout_wg.compute_log_prob(final_gen_batch_output)
                        final_gen_batch_output = final_gen_batch_output.union(output)

                    # batch.non_tensor_batch['uid'] = np.array([str(uuid.uuid4()) for _ in range(len(batch.batch))],
                    #                                          dtype=object)
                    # Each question's index appears to be stored in non_tensor_batch when the data is loaded
                    batch.non_tensor_batch['uid'] = batch.non_tensor_batch['index'].copy()

                    # repeat to align with repeated responses in rollout
                    batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)
                    batch = batch.union(final_gen_batch_output)

                ####################
                ####################

                # balance the number of valid tokens on each dp rank.
                # Note that this breaks the order of data inside the batch.
                # Please take care when you implement group based adv computation such as GRPO and rloo
                self._balance_batch(batch, metrics=metrics)

                # compute global_valid tokens
                batch.meta_info['global_token_num'] = torch.sum(batch.batch['attention_mask'], dim=-1).tolist()

                # batch.batch.apply(lambda x, key: x.long() if key != "old_log_probs" else x, inplace=True, key=True)
                for key in batch.batch.keys():
                    if key != 'old_log_probs':
                        batch.batch[key] = batch.batch[key].long()

                if self.use_reference_policy:
                    # compute reference log_prob
                    with _timer('ref', timing_raw):
                        ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)
                        batch = batch.union(ref_log_prob)

                # compute values
                if self.use_critic:
                    with _timer('values', timing_raw):
                        values = self.critic_wg.compute_values(batch)
                        batch = batch.union(values)

                with _timer('adv', timing_raw):
                    # compute scores. Support both model and function-based.
                    # We first compute the scores using reward model. Then, we call reward_fn to combine
                    # the results from reward model and rule-based results.
                    if self.use_rm:
                        # we first compute reward model score
                        reward_tensor = self.rm_wg.compute_rm_score(batch)
                        batch = batch.union(reward_tensor)

                    # we combine with rule-based rm
                    reward_tensor = self.reward_fn(batch)
                    batch.batch['token_level_scores'] = reward_tensor

                    # compute rewards. apply_kl_penalty if available
                    if not self.config.actor_rollout_ref.actor.use_kl_loss:
                        batch, kl_metrics = apply_kl_penalty(batch,
                                                             kl_ctrl=self.kl_ctrl,
                                                             kl_penalty=self.config.algorithm.kl_penalty)
                        metrics.update(kl_metrics)
                    else:
                        batch.batch['token_level_rewards'] = batch.batch['token_level_scores']

                    # compute advantages, executed on the driver process
                    batch = compute_advantage(batch,
                                              adv_estimator=self.config.algorithm.adv_estimator,
                                              gamma=self.config.algorithm.gamma,
                                              lam=self.config.algorithm.lam,
                                              num_repeat=self.config.actor_rollout_ref.rollout.n)

                # update critic
                if self.use_critic:
                    with _timer('update_critic', timing_raw):
                        critic_output = self.critic_wg.update_critic(batch)
                    critic_output_metrics = reduce_metrics(critic_output.meta_info['metrics'])
                    metrics.update(critic_output_metrics)

                # implement critic warmup
                if self.config.trainer.critic_warmup <= self.global_steps:
                    # update actor
                    with _timer('update_actor', timing_raw):
                        if self.config.do_search and self.config.actor_rollout_ref.actor.state_masking:
                            batch, metrics = self._create_loss_mask(batch, metrics)
                        actor_output = self.actor_rollout_wg.update_actor(batch)
                    actor_output_metrics = reduce_metrics(actor_output.meta_info['metrics'])
                    metrics.update(actor_output_metrics)

                # validate
                if self.val_reward_fn is not None and self.config.trainer.test_freq > 0 and \
                        self.global_steps % self.config.trainer.test_freq == 0:
                    with _timer('testing', timing_raw):
                        val_metrics: dict = self._validate()
                    metrics.update(val_metrics)

                if self.config.trainer.save_freq > 0 and \
                        self.global_steps % self.config.trainer.save_freq == 0:
                    with _timer('save_checkpoint', timing_raw):
                        self._save_checkpoint()

            # collect metrics
            metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic))
            metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw))

            # TODO: make a canonical logger that supports various backend
            logger.log(data=metrics, step=self.global_steps)

            self.global_steps += 1

            if self.global_steps >= self.total_training_steps:
                # perform validation after training
                if self.val_reward_fn is not None:
                    val_metrics = self._validate()
                    pprint(f'Final validation metrics: {val_metrics}')
                    logger.log(data=val_metrics, step=self.global_steps)
                return
```
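The training loop above delegates scoring to self.reward_fn(batch), which fills token_level_scores. This is the reward-manager half of the two difficulties listed at the start of the code section. Below is a minimal sketch under stated assumptions: the reward is exact match on the <answer> span with SQuAD-style answer normalization, and the scalar is placed on the last valid response token. All names are hypothetical. The real Search-R1 scorer may differ in details, such as which <answer> block it reads.

```python
import re
import string
from typing import List, Optional, Sequence

ANSWER_RE = re.compile(r"<answer>(.*?)</answer>", re.DOTALL)
PUNCTUATION = set(string.punctuation)


def normalize_answer(text: str) -> str:
    """SQuAD-style normalization: lowercase, strip punctuation and articles, squeeze spaces."""
    text = "".join(ch for ch in text.lower() if ch not in PUNCTUATION)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())


def extract_answer(response: str) -> Optional[str]:
    """Content of the last <answer>...</answer> block in the response, or None."""
    matches = ANSWER_RE.findall(response)
    return matches[-1] if matches else None


def em_score(response: str, golds: Sequence[str]) -> float:
    """1.0 if the extracted answer matches any gold answer after normalization, else 0.0."""
    pred = extract_answer(response)
    if pred is None:
        return 0.0
    return float(any(normalize_answer(pred) == normalize_answer(g) for g in golds))


def token_level_scores(response_mask: Sequence[int], score: float) -> List[float]:
    """Place the scalar score on the last valid response token, zeros elsewhere.

    Assumes the response is right-padded, so the valid tokens form a prefix of the mask.
    """
    scores = [0.0] * len(response_mask)
    valid_len = sum(response_mask)
    if valid_len > 0:
        scores[valid_len - 1] = score
    return scores
```

Putting the whole outcome reward on one token keeps the reward sparse. The advantage estimator (GAE for PPO, group normalization for GRPO) then spreads it back over the trajectory.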
2. Tool use
To be updated.