Llama2-Chinese项目:8-TRL资料整理

news/2024/7/24 12:28:21 标签: Llama2, LLM, TRL

  TRL(Transformer Reinforcement Learning)是一个使用强化学习来训练Transformer语言模型和Stable Diffusion模型的Python类库工具集,听上去很抽象,但如果说主要是做SFT(Supervised Fine-tuning)、RM(Reward Modeling)、RLHF(Reinforcement Learning from Human Feedback)和PPO(Proximal Policy Optimization)等的话,肯定就很熟悉了。最重要的是TRL构建于transformers库之上,两者均由Hugging Face公司开发。

一.TRL类库
1.TRL类库介绍
简单理解就是可以通过TRL库做RLHF训练,如下所示:


(1)SFTTrainer:是一个轻量级、友好的transformers Trainer包装器,可轻松在自定义数据集上微调语言模型或适配器。
(2)RewardTrainer:是一个轻量级的transformers Trainer包装器,可轻松为人类偏好(奖励建模)微调语言模型。
(3)PPOTrainer:一个PPO训练器,用于语言模型,只需要(query, response, reward)三元组来优化语言模型。
(4)AutoModelForCausalLMWithValueHead & AutoModelForSeq2SeqLMWithValueHead:一个带有额外标量输出的transformer模型,每个token都可以用作强化学习中的值函数。
(5)Examples:使用BERT情感分类器训练GPT2生成积极的电影评论,仅使用适配器的完整RLHF,训练GPT-j以减少毒性,Stack-Llama例子等。
2.PPO工作原理
  通过PPO对语言模型进行微调大致包括三个步骤:
(1)Rollout:语言模型根据query生成response或continuation,query可以是一个句子的开头。
(2)Evaluation:使用函数、模型、人类反馈或它们的某些组合对查询和响应进行评估。重要的是,此过程应为每个query/response对生成一个标量值。
(3)Optimization:这是最复杂的部分。在优化步骤中,query/response对用于计算序列中token的对数概率。这是使用经过训练的模型和Reference model完成的,Reference model通常是微调前的预训练模型。两个输出之间的KL散度用作额外的奖励信号,以确保生成的response不会偏离Reference model太远。然后使用PPO训练Active model。

二.TRL安装和使用方式
1.TRL安装

# 直接安装包
pip install trl

# 从源码安装
git clone https://github.com/huggingface/trl.git
cd trl/
pip install .

2.SFTTrainer使用方式
  SFTTrainer是围绕transformer Trainer的轻量级封装,可以轻松微调自定义数据集上的语言模型或适配器。如下所示:

# 导入Python包
from datasets import load_dataset
from trl import SFTTrainer

# 加载imdb数据集
dataset = load_dataset("imdb", split="train")

# 得到trainer
trainer = SFTTrainer(
    "facebook/opt-350m",
    train_dataset=dataset,
    dataset_text_field="text",
    max_seq_length=512,
)

# 开始训练
trainer.train()

3.RewardTrainer使用方式
  RewardTrainer是围绕transformers Trainer的封装,可以轻松在自定义偏好数据集上微调奖励模型或适配器。如下所示:

# 导入Python包
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from trl import RewardTrainer

# 加载模型和数据集,数据集需要为指定格式
model = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=1)
tokenizer = AutoTokenizer.from_pretrained("gpt2")
...
# 得到trainer
trainer = RewardTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset,
)

# 开始训练
trainer.train()

4.PPOTrainer使用方式
  query通过语言模型输出一个response,然后对其进行评估。评估可以人类反馈,也可以是另一个模型的输出。如下所示:

# 导入Python包
import torch
from transformers import AutoTokenizer
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead, create_reference_model
from trl.core import respond_to_batch

# 首先加载模型,然后创建参考模型
model = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2')
model_ref = create_reference_model(model)
tokenizer = AutoTokenizer.from_pretrained('gpt2')

# 初始化ppo配置对象
ppo_config = PPOConfig(
    batch_size=1,
)

# 编码一个query
query_txt = "This morning I went to the "
query_tensor = tokenizer.encode(query_txt, return_tensors="pt")

# 得到模型response
response_tensor  = respond_to_batch(model, query_tensor)

# 创建一个ppo trainer
ppo_trainer = PPOTrainer(ppo_config, model, model_ref, tokenizer)

# 为response定义一个reward(人类反馈或模型输出奖励) 
reward = [torch.tensor(1.0)]

# 使用ppo训练一步模型
train_stats = ppo_trainer.step([query_tensor[0]], [response_tensor[0]], reward)

参考文献:
[1]https://github.com/huggingface/trl
[2]https://huggingface.co/docs/trl/v0.7.1/en/index


http://www.niftyadmin.cn/n/5067219.html

相关文章

Apollo Planning2.0决策规划算法代码详细解析 (2): vscode gdb单步调试环境搭建

前言: apollo planning2.0 在新版本中在降低学习和二次开发成本上进行了一些重要的优化,重要的优化有接口优化、task插件化、配置参数改造等。 GNU symbolic debugger,简称「GDB 调试器」,是 Linux 平台下最常用的一款程序调试器。GDB 编译器通常以 gdb 命令的形式在终端…

优思学院|六西格玛将烹饪和美味提升至极致

最近,我们曾提到一个美国男子如何利用六西格玛来控制糖尿病。这表明六西格玛逐渐被认为是一个不仅可以在工作场所之外使用,尤其不仅限于制造业的系统。 六西格玛的核心理念是改进过程的质量,从而改善最终结果。如果你做了晚餐或尝试了一道新…

动态规划五步曲

一、什么是动态规划五步曲 确定dp数组(dp table)以及下标的含义 确定递推公式 dp数组如何初始化 确定遍历顺序 举例推导dp数组 二、 个人赏析 这是我从某网站上看到的关于动态规划的教学系列。作为应试来说,这个 提纲确实不错,…

【1++的刷题系列】之双指针

👍作者主页:进击的1 🤩 专栏链接:【1的刷题系列】 文章目录 一,什么是双指针二,相关例题例一例二例三例四例五 一,什么是双指针 常见的双指针有两种形式:一种是对撞指针&#xff08…

【成像光敏描记图提取和处理】成像-光电容积描记-提取-脉搏率-估计(Matlab代码实现)

💥💥💞💞欢迎来到本博客❤️❤️💥💥 🏆博主优势:🌞🌞🌞博客内容尽量做到思维缜密,逻辑清晰,为了方便读者。 ⛳️座右铭&a…

python实现进制转换

在Python中,可以编写一个函数来将一个数从一个进制转换到另一个进制。以下是一个简单的实现示例: def convert_base(num, base_from, base_to):"""将一个数从一个进制转换到另一个进制。参数:num: 需要转换的数base_from: 原始进制 (必须是2到36)base_to: 目…

语义分割 Semantic Segmentation

之前了解过语义分割的内容,感觉可以做好多东西,然后就抽空学习了一下,这里记录一下方便以后查阅,这篇文章可能也会随着学习的深入不断更新。 语义分割 Semantic Segmentation 一些基本概念几种语义分割算法Fully Convolutional Ne…

安装使用TinyCore Linux的一些收获

为了学习Linux Shell编程,决定安装一个纯粹的Linux,由于电脑硬件配置较低,选择了最轻量化Llinux操作系统版本TinyCore Linux。 一、TinyCore Linux有三个版本 打开TinyCore Linux的下载页面 http://www.tinycorelinux.net/downloads.html&a…