OpenLLaMA2
最近联合 xianyu(https://github.com/catqaq) 等 OpenLLMAI 开源组织的同学搞了一个开源项目(欢迎 star)
GitHub - OpenLLMAI/OpenLLaMA2: DeepSpeed+Ray based LLaMA2 SFT/RLHF training frameworkgithub.com/OpenLLMAI/OpenLLaMA2
主要目的是基于 Ray+DeepSpeed 打造一个分布式 SFT/RLHF 训练框架,最终目的是基于 llama2 模型从预训练到 RLHF 都做一遍,目前正在拉资源和人力。最近我们完成了第一版基于deepspeed的代码,并且实现了在单卡 A100 80G 上可以训练 7B (actor)+7B (critic) 的 llama2 PPO 模型(得益于我们用各种trick节省了显卡内存,并同时保证了相当快的训练速度) ,虽然最终为了节省训练时间仍然使用4卡训练。相对来说,Huggingface TRL 使用的actor/critic共享模型,我们这里是独立的的模型,使得 Critic 可以用 RM 权重初始化。毕竟我们最终的目标工业级可应用框架
本文主要介绍一下框架快速使用方法,因为其已经内置所有数据集和可收敛的超参数,所以可以一键使用。并且提供了基于开源数据训练过的 SFT/RM 模型。
SFT/RM模型地址
chuyi777/openllama2_checkpoint · Hugging Face
Env
首先我们要启动 docker 环境避免各种复杂依赖项带来的问题,脚本如下
https://github.com/OpenLLMAI/OpenLLaMA2/blob/main/examples/scripts/docker_run.sh
PROJECT_PATH=$(cd $(dirname $0)/../../; pwd)
IMAGE_NAME="nvcr.io/nvidia/pytorch:23.05-py3"
docker run --runtime=nvidia -it --rm --shm-size="1g" --cap-add=SYS_ADMIN \
-u $(id -u):$(id -g) \
-v $PROJECT_PATH:/openllama2 -v $HOME/.cache:/.cache -v $HOME/.bash_history2:/.bash_history \
-v $HOME/.local:/.local -v $HOME/.triton:/.triton \
$IMAGE_NAME bash
这个脚本会自动挂载项目路径以及 .cache 等缓存路径(为了复用 huggingface 模型数据缓存)
然后我们需要在container中的项目跟目录 /openllama2 下执行
python -m pip install .
安装 deepspeed 等未在NVIDIA container 中安装的依赖项,如果确定自己的环境很OK不会报错也可以不使用container直接这样安装。
SFT
然后是训练 SFT 模型,这里默认使用500k 的 ocra 数据集,可以自行更换。
https://github.com/OpenLLMAI/OpenLLaMA2/blob/main/examples/scripts/train_sft_llama.sh
deepspeed ../train_sft.py \
--max_len 2048 \
--dataset 'Open-Orca/OpenOrca' \
--dataset_probs '1.0' \
--train_batch_size 128 \
--micro_train_batch_size 1 \
--max_samples 500000 \
--pretrain "meta-llama/Llama-2-7b-hf" \
--save_path "./ckpt/7b_llama" \
--zero_stage 2 \
--max_epochs 1 \
--bf16 \
--learning_rate 5e-6 \
--gradient_checkpointing
注意:这个配置是4卡的 如果你使用单卡。请增加选项 --adam_offload,但速度会比较慢。
RM
接下来是 RM 模型,我们混了了好几个开源的人工标注数据集(Anthorpic,LMsys,oasst,webgpt),以helpfulness为主。
其中 --datasets 指定的就是混合数据集的列表
--dataset_probs 指定对应的采样概率
这个混合数据的方法在 SFT/PPO模块同样适用
--dataset 'Anthropic/hh-rlhf,tasksource/oasst1_pairwise_rlhf_reward,lmsys/chatbot_arena_conversations,openai/webgpt_comparisons
https://github.com/OpenLLMAI/OpenLLaMA2/blob/main/examples/scripts/train_rm_llama.sh
deepspeed ../train_rm.py \
--save_path "./ckpt/7b_llama" \
--train_batch_size 128 \
--micro_train_batch_size 1 \
--pretrain "meta-llama/Llama-2-7b-hf" \
--bf16 \
--max_epochs 1 \
--max_len 1024 \
--zero_stage 3 \
--learning_rate 9e-6 \
--dataset 'Anthropic/hh-rlhf,tasksource/oasst1_pairwise_rlhf_reward,lmsys/chatbot_arena_conversations,openai/webgpt_comparisons' \
--dataset_probs '0.72,0.08,0.12,0.08' \
--load_model './ckpt/7b_llama/sft_model.pt'
# --gradient_checkpointing [BUG] https://github.com/microsoft/DeepSpeedExamples/issues/458
注意:这个配置是4卡的 如果你使用单卡。请增加选项 --adam_offload,并且设置 zero_stage 2,但速度会比较慢。
RLHF
最后就是我们的 PPO 训练了,PPO的训练prompts同样来自于各种SFT/RM 数据集抽取prompts的混合体。
https://github.com/OpenLLMAI/OpenLLaMA2/blob/main/examples/scripts/train_ppo_llama.sh
deepspeed ../train_ppo.py \
--pretrain "meta-llama/Llama-2-7b-hf" \
--critic_pretrain "meta-llama/Llama-2-7b-hf" \
--reward_model_path "./ckpt/7b_llama/rm_model.pt" \
--sft_model_path "./ckpt/7b_llama/sft_model.pt" \
--save_path './ckpt/7b_llama' \
--micro_train_batch_size 1 \
--train_batch_size 128 \
--micro_rollout_batch_size 1 \
--rollout_batch_size 1024 \
--max_epochs 1 \
--prompt_max_len 1024 \
--generate_max_len 1024 \
--zero_stage 2 \
--bf16 \
--actor_learning_rate 5e-7 \
--critic_learning_rate 9e-6 \
--inference_tp_size 1 \
--init_kl_coef 0.01 \
--prompt_data 'yahma/alpaca-cleaned,Dahoas/full-hh-rlhf,tasksource/oasst1_pairwise_rlhf_reward' \
--prompt_data_probs '0.3,0.6,0.1' \
--normalize_reward \
--adam_offload \
--gradient_checkpointing
该配置支持 1 卡到 N 卡
Example
最后给出一个RLHF训练前后的对比的例子
Ocra 500k SFT 模型
问题 1
Human: Please introduce the game GTA5
Assistant: GTA5, or Grand Theft Auto V, is an action-adventure video game developed by Rockstar North and published by Rockstar Games. It was released in 2013 for the PlayStation 3 and Xbox 360, and later in 2014 for the PlayStation 4 and Xbox One. The game is set in the fictional city of Los Santos, based on Los Angeles, and follows the story of three criminals as they plan and execute heists.
GTA5 is an open-world game, meaning that players can explore the game's vast and detailed world at their leisure, completing missions, engaging in side activities, and interacting with other characters. The game features an extensive array of vehicles, from cars to boats to planes, and players can use them to travel around the city and engage in various forms of criminal activity.
GTA5 is known for its immersive gameplay, stunning graphics, and rich storytelling. It has received critical acclaim and has become one of the best-selling video games of all time, with over 150 million copies sold worldwide.
Ocra 500k SFT+PPO 训练后的模型
Human: Please introduce the game GTA5
Assistant: 1. **Game Name:** Grand Theft Auto V (GTA V)
2. **Platforms:** GTA V is available on PlayStation 3 and Xbox 360, and later released on PlayStation 4, Xbox One, and Microsoft Windows.
3. **Genre:** Action-adventure game
4. **Gameplay:** GTA V is an open-world game where players can explore the fictional city of Los Santos and its surrounding areas, completing missions, engaging in side activities, and interacting with non-player characters. The game features a third-person perspective and allows players to freely roam the game world, engage in combat, and complete objectives.
5. **Storyline:** The game follows the story of three protagonists, Michael De Santa, Franklin Clinton, and Trevor Philips, as they become embroiled in a series of heists and other criminal activities. The game's storyline is considered one of the strongest aspects of GTA
6. **Graphics and Sound:** GTA V features stunning graphics and realistic character animations, making the game world feel alive and immersive. The soundtrack and voice acting are also highly praised, with a diverse and engaging musical score that complements the game's action-packed gameplay.
7. **Reception:** GTA V received critical acclaim upon its release, with many praising its storyline, gameplay, graphics, and sound. It is considered one of the best video games of all time and has sold over 150 million copies worldwide, making it one of the best-selling video games of all time.
In summary, GTA V is an action-adventure game set in a massive open world, with a complex and engaging storyline, stunning graphics, and highly praised soundtrack and voice acting. It is considered one of the greatest video games of all time and continues to be a popular choice among gamers worldwide.
问题2
Ocra 500k SFT 模型
Human: Which GPU is better, AMD or NVIDIA?
Assistant:
The performance and suitability of a GPU (Graphics Processing Unit) depend on various factors, including the specific task or application being used, the model and generation of the GPU, and the system it is being used in. Both AMD and NVIDIA are well-known and reputable manufacturers of GPUs, and they both have a wide range of models and generations.
In general, the performance of a GPU can vary depending on the specific task or application being used. For example, AMD's GPUs are known for their performance in gaming and video editing, while NVIDIA's GPUs are known for their performance in machine learning and deep learning.
It is important to note that the performance of a GPU can also vary depending on the generation of the GPU. Newer generations of GPUs tend to have better performance and features than older generations.
Ultimately, the best GPU for you will depend on your specific needs and the tasks you plan to use it for. It is recommended to research and compare the specific models and generations of AMD and NVIDIA GPUs to determine which one is the best fit for your needs.
Ocra 500k SFT+PPO 训练后的模型
Human: Which GPU is better, AMD or NVIDIA?
Assistant: 1. **Performance:** In terms of raw performance, NVIDIA typically outperforms AMD in most cases, especially in gaming and video editing.
2. **Price:** NVIDIA GPUs tend to be more expensive than AMD GPUs, although this can vary depending on the specific model and its features.
3. **Architecture:** NVIDIA's GPU architecture is based on the CUDA (Compute Unified Device Architecture) platform, which allows for more efficient processing of parallel computing tasks. AMD's GPU architecture, on the other hand, is based on the Stream Compute architecture, which is designed for general-purpose computing.
4. **Application Support:** NVIDIA's GPUs have better support for certain applications, particularly in the fields of deep learning and artificial intelligence. AMD's GPUs, however, have better support for certain games and video editing software.
In summary, NVIDIA's GPUs tend to be more powerful and efficient for certain tasks, while AMD's GPUs are better suited for other tasks and have better support for certain applications. Ultimately, the choice between the two depends on the specific needs and preferences of the user.
可以看出来PPO在helpfulness人类偏好数据集上主要提升了行文排版和内容丰富度,以及问题相关性
而且 Ocra 应该本身就是质量比较高的 SFT 数据集了
后续开发计划
后续我们准备引入 Ray 甚至推理框架以支持 70B 模型的 RLHF 训练
还有 qlora 等技术,并且完善 llama2 rlhf 的整个流程含 rejection sampling,直到形成一个完整的训练闭环
我们也在打造 RLHF 中文数据集,欢迎感兴趣的同学一起参与
组织链接 OpenLLMAI