On the GRPO Algorithm: How the DeepSeek R1 Model Is Trained, Seen Through Open R1

Overview

First published on my personal WeChat official account: 阿郎小哥的随笔驿站

For the DeepSeek R1 series, it is recommended to read the earlier posts first:

Some Takeaways on DeepSeek R1

Open R1, the Open-Source Reproduction of DeepSeek R1: Synthetic Data

Knowledge Distillation of DeepSeek R1 and Thoughts on Its Applications

Introduction

GRPO is an online learning algorithm: it improves iteratively by training on data that the model being trained generates itself during training. The intuition behind the GRPO objective is to maximize the advantage of the generated completions while keeping the model close to the reference policy.

The four main steps of GRPO: generate completions, compute the advantage, estimate the KL divergence, and compute the loss.

[Figure: overview of the four GRPO steps]

Unlike traditional RL methods, which typically rely on an external evaluator (a critic) to guide learning, GRPO optimizes the model by comparing a group of responses against one another. This improves training efficiency and makes GRPO perform particularly well on reasoning tasks that demand complex problem solving and long chains of thought.

Step-by-Step Breakdown

Step 1: Select a query

• Sample a query $q$ from the training data distribution $P(Q)$.

• Example: suppose the query is "What is the sum of 8 + 5?"

Step 2: Generate a group of responses

• The model generates a group of $G$ responses to this query.

• Example: the model generates the following responses:

• o1: "The answer is 13."

• o2: "Thirteen."

• o3: "It's 12."

• o4: "The sum is 13."

Step 3: Compute a reward for each response

• What is a reward? A reward quantifies the quality of a response and thereby guides the model's learning.

• Reward types in GRPO:

• Accuracy reward: based on the correctness of the response (e.g., solving a math problem correctly).

• Format reward: ensures the response meets the structural requirements (e.g., the reasoning process must be enclosed in the designated tags).

• Language consistency reward: penalizes responses that mix languages or are formatted inconsistently.

• Each response is assigned a reward $r_i$ according to how good it is.

For example, the reward might depend on:

• Accuracy: is the answer correct?

• Format: is the response well structured?

Example:

• r1 = 1.0 (correct and well formatted)

• r2 = 0.9 (correct but less formal)

• r3 = 0.0 (wrong answer)

• r4 = 1.0 (correct and well formatted)

Step 4: Compare responses (group advantage)

• Compute each response's advantage $A_i$ relative to the group. The relevant notation in the paper is as follows:

[Figure: the group-relative advantage as defined in the paper]
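For reference, the advantage definition from the DeepSeekMath/DeepSeek-R1 papers (my transcription, with $r_i$ the reward of response $i$ and $G$ the group size) normalizes each reward against the statistics of its group:

$$\hat{A}_i = \frac{r_i - \operatorname{mean}(\{r_1, r_2, \ldots, r_G\})}{\operatorname{std}(\{r_1, r_2, \ldots, r_G\})}$$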

Put simply, it can be understood like this:

[Figure: the same calculation shown informally]

• Responses that score better than the group average receive a positive advantage, while worse-performing responses receive a negative one.

• This sets up competition within the group and pushes the model to generate better responses (see the short sketch below).
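A minimal sketch of this normalization in plain Python, reusing the example rewards from Step 3 (an illustration of the idea only, not the trainer's actual implementation):

# Group-relative advantages for the example rewards above:
# A_i = (r_i - mean(r)) / std(r), with a small epsilon guarding against zero variance.
rewards = [1.0, 0.9, 0.0, 1.0]  # r1..r4 from Step 3

mean = sum(rewards) / len(rewards)
std = (sum((r - mean) ** 2 for r in rewards) / len(rewards)) ** 0.5

advantages = [(r - mean) / (std + 1e-8) for r in rewards]
print([round(a, 2) for a in advantages])  # [0.65, 0.42, -1.72, 0.65]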

Step 5: Update the policy with clipping

[Figure: the clipped GRPO objective]
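Written out, a simplified sequence-level form of the clipped objective reads as follows (my paraphrase of the paper's formula; the full version averages over tokens, and the KL term $\beta\,\mathbb{D}_{\mathrm{KL}}$ is the subject of Step 6):

$$\mathcal{J}_{\text{GRPO}}(\theta) = \mathbb{E}\left[\frac{1}{G}\sum_{i=1}^{G} \min\left(\frac{\pi_\theta(o_i \mid q)}{\pi_{\theta_{\text{old}}}(o_i \mid q)}\,\hat{A}_i,\ \operatorname{clip}\left(\frac{\pi_\theta(o_i \mid q)}{\pi_{\theta_{\text{old}}}(o_i \mid q)},\, 1-\epsilon,\, 1+\epsilon\right)\hat{A}_i\right) - \beta\,\mathbb{D}_{\mathrm{KL}}\left(\pi_\theta \,\|\, \pi_{\text{ref}}\right)\right]$$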

Example: if the new policy starts to assign too high a probability to o1, the clipping mechanism ensures that this response is not over-emphasized.

This keeps policy optimization stable and reliable even on complex tasks such as reasoning.

Step 6: Penalize deviation with KL divergence

[Figure: the KL penalty term]
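Rather than the textbook KL formula, the paper uses an unbiased estimator that is always non-negative; in the notation of the objective above it is:

$$\mathbb{D}_{\mathrm{KL}}\left(\pi_\theta \,\|\, \pi_{\text{ref}}\right) = \frac{\pi_{\text{ref}}(o_i \mid q)}{\pi_\theta(o_i \mid q)} - \log\frac{\pi_{\text{ref}}(o_i \mid q)}{\pi_\theta(o_i \mid q)} - 1$$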

GRPO Implementation

Open R1

GRPO-based training is one stage in Open R1's reproduction roadmap:

[Figure: Open R1's reproduction roadmap]

The training is launched with the following script (note the 7 training processes: one GPU is reserved for vLLM generation, as the config below explains):

ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/zero3.yaml --num_processes=7 src/open_r1/grpo.py --config recipes/qwen/Qwen2.5-1.5B-Instruct/grpo/confg_full.yaml 

confg_full.yaml

# Base model
model_name_or_path: deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B
model_revision: main
torch_dtype: bfloat16

# Training dataset
dataset_name: AI-MO/NuminaMath-TIR
dataset_configs:
- all
# Num processes is less by 1 as vLLM is using 1 GPU
num_processes: 7

# GRPO trainer parameters
bf16: true
use_vllm: true
vllm_device: auto
vllm_gpu_memory_utilization: 0.7
do_eval: true
eval_strategy: steps
eval_steps: 100
gradient_accumulation_steps: 16
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
hub_model_id: Qwen2.5-1.5B-Open-R1-GRPO
hub_strategy: every_save
learning_rate: 2.0e-05
log_level: info
logging_steps: 10
logging_strategy: steps
lr_scheduler_type: cosine
max_prompt_length: 512
max_completion_length: 1024
max_steps: -1
num_train_epochs: 1
output_dir: data/Qwen2.5-1.5B-Open-R1-GRPO
overwrite_output_dir: true
per_device_eval_batch_size: 4
per_device_train_batch_size: 1
push_to_hub: true
report_to:
- wandb
save_strategy: "no"
seed: 42
warmup_ratio: 0.1

Open R1 implements GRPO training in grpo.py. With some unrelated code trimmed, the key program logic is as follows:

@dataclass
class GRPOScriptArguments(ScriptArguments):
    reward_funcs: list[str] = field(
        default_factory=lambda: ["accuracy", "format"],
        metadata={"help": "List of reward functions. Possible values: 'accuracy', 'format'"},
    )


def accuracy_reward(completions, solution, **kwargs):
    """Reward function that checks if the completion is the same as the ground truth."""
    contents = [completion[0]["content"] for completion in completions]
    rewards = []
    for content, sol in zip(contents, solution):
        gold_parsed = parse(sol, extraction_mode="first_match", extraction_config=[LatexExtractionConfig()])
        if len(gold_parsed) != 0:
            # We require the answer to be provided in correct latex (no malformed operators)
            answer_parsed = parse(
                content,
                extraction_config=[
                    LatexExtractionConfig(
                        normalization_config=NormalizationConfig(
                            nits=False,
                            malformed_operators=False,
                            basic_latex=True,
                            equations=True,
                            boxed=True,
                            units=True,
                        ),
                        # Ensures that boxed is tried first
                        boxed_match_priority=0,
                        try_extract_without_anchor=False,
                    )
                ],
                extraction_mode="first_match",
            )
            # Reward 1 if the content is the same as the ground truth, 0 otherwise
            reward = float(verify(answer_parsed, gold_parsed))
        else:
            # If the gold solution is not parseable, we reward 1 to skip this example
            reward = 1.0
            print("Failed to parse gold solution: ", sol)
        rewards.append(reward)

    return rewards


def format_reward(completions, **kwargs):
    """Reward function that checks if the completion has a specific format."""
    pattern = r"^<think>.*?</think><answer>.*?</answer>$"
    completion_contents = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, content) for content in completion_contents]
    return [1.0 if match else 0.0 for match in matches]


reward_funcs_registry = {
    "accuracy": accuracy_reward,
    "format": format_reward,
}

SYSTEM_PROMPT = (
    "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant "
    "first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning "
    "process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., "
    "<think> reasoning process here </think><answer> answer here </answer>"
)


def main(script_args, training_args, model_args):
    # Load the dataset
    dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)

    # Get reward functions
    reward_funcs = [reward_funcs_registry[func] for func in script_args.reward_funcs]

    # Format into conversation
    def make_conversation(example):
        return {
            "prompt": [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": example["problem"]},
            ],
        }

    dataset = dataset.map(make_conversation)
    for split in dataset:
        if "messages" in dataset[split].column_names:
            dataset[split] = dataset[split].remove_columns("messages")

    logger.info("*** Initializing model kwargs ***")
    torch_dtype = (
        model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
    )
    model_kwargs = dict(
        revision=model_args.model_revision,
        trust_remote_code=model_args.trust_remote_code,
        attn_implementation=model_args.attn_implementation,
        torch_dtype=torch_dtype,
        use_cache=False if training_args.gradient_checkpointing else True,
    )
    training_args.model_init_kwargs = model_kwargs

    #############################
    # Initialize the GRPO trainer
    #############################
    trainer = GRPOTrainer(
        model=model_args.model_name_or_path,
        reward_funcs=reward_funcs,
        args=training_args,
        train_dataset=dataset[script_args.dataset_train_split],
        eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
        peft_config=get_peft_config(model_args),
        callbacks=get_callbacks(training_args, model_args),
    )

    ###############
    # Training loop
    ###############
    logger.info("*** Train ***")
    checkpoint = None
    if training_args.resume_from_checkpoint is not None:
        checkpoint = training_args.resume_from_checkpoint
    elif last_checkpoint is not None:
        checkpoint = last_checkpoint
    train_result = trainer.train(resume_from_checkpoint=checkpoint)
    metrics = train_result.metrics
    metrics["train_samples"] = len(dataset[script_args.dataset_train_split])
    trainer.log_metrics("train", metrics)
    trainer.save_metrics("train", metrics)
    trainer.save_state()

    ##################################
    # Save model and create model card
    ##################################
    trainer.save_model(training_args.output_dir)

    # Save everything else on main process
    kwargs = {
        "dataset_name": script_args.dataset_name,
        "tags": ["open-r1"],
    }
    if trainer.accelerator.is_main_process:
        trainer.create_model_card(**kwargs)
        # Restore k,v cache for fast inference
        trainer.model.config.use_cache = True
        trainer.model.config.save_pretrained(training_args.output_dir)

    ##########
    # Evaluate
    ##########
    if training_args.do_eval:
        logger.info("*** Evaluate ***")
        metrics = trainer.evaluate()
        metrics["eval_samples"] = len(dataset[script_args.dataset_test_split])
        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)

    #############
    # push to hub
    #############
    if training_args.push_to_hub:
        logger.info("Pushing to hub...")
        trainer.push_to_hub(**kwargs)


if __name__ == "__main__":
    parser = TrlParser((GRPOScriptArguments, GRPOConfig, ModelConfig))
    script_args, training_args, model_args = parser.parse_args_and_config()
    main(script_args, training_args, model_args)

The code breaks down as follows:

The first step is loading the dataset. While loading, a designated prompt is attached to every example by the make_conversation function, which builds the prompt that steers the model's output. The resulting format is:

{
    "prompt": [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": example["problem"]},
    ],
}

The SYSTEM_PROMPT reads:

"A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant "     "first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning "     "process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., "     "<think> reasoning process here </think><answer> answer here </answer>" 

In short, it instructs the model to think through its reasoning first and then to put the reasoning process and the reply inside the designated <think> and <answer> tags.

Next come the reward functions. This GRPO setup registers two rewards, an accuracy reward and a format reward, shown below:

def accuracy_reward(completions, solution, **kwargs):
    """Reward function that checks if the completion is the same as the ground truth."""
    contents = [completion[0]["content"] for completion in completions]
    rewards = []
    for content, sol in zip(contents, solution):
        gold_parsed = parse(sol, extraction_mode="first_match", extraction_config=[LatexExtractionConfig()])
        if len(gold_parsed) != 0:
            # We require the answer to be provided in correct latex (no malformed operators)
            answer_parsed = parse(
                content,
                extraction_config=[
                    LatexExtractionConfig(
                        normalization_config=NormalizationConfig(
                            nits=False,
                            malformed_operators=False,
                            basic_latex=True,
                            equations=True,
                            boxed=True,
                            units=True,
                        ),
                        # Ensures that boxed is tried first
                        boxed_match_priority=0,
                        try_extract_without_anchor=False,
                    )
                ],
                extraction_mode="first_match",
            )
            # Reward 1 if the content is the same as the ground truth, 0 otherwise
            reward = float(verify(answer_parsed, gold_parsed))
        else:
            # If the gold solution is not parseable, we reward 1 to skip this example
            reward = 1.0
            print("Failed to parse gold solution: ", sol)
        rewards.append(reward)

    return rewards


def format_reward(completions, **kwargs):
    """Reward function that checks if the completion has a specific format."""
    pattern = r"^<think>.*?</think><answer>.*?</answer>$"
    completion_contents = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, content) for content in completion_contents]
    return [1.0 if match else 0.0 for match in matches]


reward_funcs_registry = {
    "accuracy": accuracy_reward,
    "format": format_reward,
}
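To make the input format concrete, here is a tiny self-contained check that mirrors what format_reward does, run on two invented completions (accuracy_reward consumes the same structure but additionally needs the parse/verify LaTeX helpers imported in the full script):

import re

# The same pattern used by format_reward above.
pattern = r"^<think>.*?</think><answer>.*?</answer>$"

# Toy completions, invented for illustration: each completion is a list of chat
# messages, and the reward functions read completion[0]["content"].
completions = [
    [{"role": "assistant", "content": "<think>8 + 5 = 13</think><answer>13</answer>"}],
    [{"role": "assistant", "content": "The sum is 13."}],  # missing the required tags
]

contents = [completion[0]["content"] for completion in completions]
print([1.0 if re.match(pattern, c) else 0.0 for c in contents])  # [1.0, 0.0]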

The last step is training. GRPOTrainer is a Trainer-based training class provided by the trl library (built on top of transformers); passing in the appropriate arguments is all that is needed to train with GRPO. The most important arguments are the reward functions and train_dataset.

#############################
# Initialize the GRPO trainer
#############################
trainer = GRPOTrainer(
    model=model_args.model_name_or_path,
    reward_funcs=reward_funcs,
    args=training_args,
    train_dataset=dataset[script_args.dataset_train_split],
    eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
    peft_config=get_peft_config(model_args),
    callbacks=get_callbacks(training_args, model_args),
)

The training schedule and checkpointing are then worked out inside the Trainer machinery from gradient_accumulation_steps (gradient accumulation steps), num_train_epochs (number of training epochs), and per_device_train_batch_size (per-device training batch size).

###############
# Training loop
###############
logger.info("*** Train ***")
checkpoint = None
if training_args.resume_from_checkpoint is not None:
    checkpoint = training_args.resume_from_checkpoint
elif last_checkpoint is not None:
    checkpoint = last_checkpoint
train_result = trainer.train(resume_from_checkpoint=checkpoint)
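As a rough back-of-the-envelope check of what those parameters imply under confg_full.yaml above (an approximation only: it treats every sample identically, ignores how GRPO groups the $G$ generations per prompt, and uses a made-up dataset size):

import math

# Values taken from the config above; num_processes excludes the GPU held by vLLM.
per_device_train_batch_size = 1
gradient_accumulation_steps = 16
num_processes = 7
num_train_epochs = 1

# Effective batch per optimizer step (approximate; GRPO additionally groups the
# G sampled completions of each prompt).
effective_batch = per_device_train_batch_size * num_processes * gradient_accumulation_steps
print(effective_batch)  # 112

# Hypothetical dataset size, just to show the shape of the calculation.
train_examples = 70_000
approx_optimizer_steps = math.ceil(train_examples / effective_batch) * num_train_epochs
print(approx_optimizer_steps)  # 625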

Summary

To sum up, Open R1's GRPO training comes down to instantiating GRPOTrainer with the right prompt/dataset/reward arguments. On the chosen training dataset, the prompt steers the format of the model's output; the GRPO algorithm and its reward functions then score each group of sampled completions against the dataset's reference answers and the required format, compute group-relative advantages, add the KL penalty against the reference policy, compute the loss, and backpropagate. Repeating this loop completes the RL training and teaches the model to produce CoT-style responses, i.e. the generate completions → compute advantages → estimate the KL divergence → compute the loss cycle shown in the figure at the beginning.
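To connect the summary back to the six steps, here is a deliberately simplified sketch of a single GRPO update on dummy per-sequence log-probabilities (pure PyTorch, no model and no GRPOTrainer; the real trainer works per token and also handles generation, batching, and masking):

import torch

# Dummy inputs for one prompt with G = 4 sampled completions. The log-probs are
# invented numbers standing in for sequence log-probabilities under the current
# policy (logp_new), the sampling-time policy (logp_old), and the frozen
# reference policy (logp_ref).
logp_new = torch.tensor([-12.0, -13.5, -11.0, -12.2], requires_grad=True)
logp_old = torch.tensor([-12.1, -13.4, -11.3, -12.0])
logp_ref = torch.tensor([-12.3, -13.6, -11.1, -12.4])
rewards = torch.tensor([1.0, 0.9, 0.0, 1.0])  # from the reward functions (Step 3)

eps, beta = 0.2, 0.04  # clip range and KL weight (illustrative values)

# Step 4: group-relative advantages.
adv = (rewards - rewards.mean()) / (rewards.std() + 1e-8)

# Step 5: clipped surrogate on the probability ratio pi_new / pi_old.
ratio = torch.exp(logp_new - logp_old)
surrogate = torch.minimum(ratio * adv, torch.clamp(ratio, 1 - eps, 1 + eps) * adv)

# Step 6: unbiased KL estimator against the reference policy.
kl = torch.exp(logp_ref - logp_new) - (logp_ref - logp_new) - 1

# Maximizing the objective is minimizing its negative; backward() produces the
# gradients that would update the policy.
loss = -(surrogate - beta * kl).mean()
loss.backward()
print(loss.item())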

The source code and documentation for the GRPOTrainer class can be found in the TRL library.
