
選自oxen.ai
作者:Greg Schoeninger
編譯:陳陳、澤南
RTX 3080 移動(dòng)版能訓(xùn)練哪種大模型?本文為那些 GPU 資源有限時(shí)使用 GRPO 訓(xùn)練的開發(fā)者提供了寶貴的指導(dǎo)。
自 DeepSeek-R1 發(fā)布以來,群組相對(duì)策略優(yōu)化(GRPO)因其有效性和易于訓(xùn)練而成為大型語言模型強(qiáng)化學(xué)習(xí)的熱門話題。R1 論文展示了如何使用 GRPO 從遵循 LLM(DeepSeek-v3)的基本指令轉(zhuǎn)變?yōu)橥评砟P停―eepSeek-R1)。
GRPO 是一種在線學(xué)習(xí)算法(online learning algorithm),它通過使用訓(xùn)練過程中由訓(xùn)練模型自身生成的數(shù)據(jù)來進(jìn)行迭代改進(jìn)。GRPO 的目標(biāo)是最大化生成補(bǔ)全(completions)的優(yōu)勢(shì)函數(shù)(advantage),同時(shí)確保模型保持在參考策略(reference policy)附近。

本文的目的是幫你節(jié)省一些時(shí)間,讓你根據(jù)硬件預(yù)算選擇合適的模型大小。在開始微調(diào)時(shí),你必須做出的重要決定是選擇模型大小,以及你是執(zhí)行完全微調(diào)還是參數(shù)高效微調(diào)(PEFT)。
文章作者來自 AI 公司 Oxen.ai 的 CEO Greg Schoeninger。

原文鏈接:https://www.oxen.ai/blog/grpo-vram-requirements-for-the-gpu-poor
作者表示,他發(fā)現(xiàn) trl 庫中已經(jīng)有一個(gè)易于使用的 GRPO 實(shí)現(xiàn),便立刻開始了訓(xùn)練,使用的硬件是配備了 16GB 顯存的 Nvidia GeForce RTX 3080 的小型筆記本電腦。正如大家可能遇到的問題,作者發(fā)現(xiàn)示例代碼中的參數(shù)設(shè)置導(dǎo)致了一個(gè)巨大的顯存不足(OOM,out of memory )錯(cuò)誤。
- torch
- OutOfMemoryError
- CUDA
- out
- of memory
- Tried
- to allocate
- 1.90
- GiB
- GPU
- 0
- has a total capacity of
- GiB
- of which
- 1.28
- GiB
- is
- free
- Including
- non
- PyTorch
- memory
- this
- process has
- GiB
- memory
- in
- use
- Of
- the allocated memory
- GiB
- is
- allocated
- by
- PyTorch
- and
- 2.41
- GiB
- is
- reserved
- by
- PyTorch
- but unallocated
- If
- reserved but unallocated memory
- is
- large
- try
- setting PYTORCH_CUDA_ALLOC_CONF
- expandable_segments
- True
- to avoid fragmentation
- See
- documentation
- for
- Memory
- Management
- //pytorch.org/docs/stable/notes/cuda.html#environment-variables)
實(shí)際使用情況
作者表示,他們進(jìn)行了一系列實(shí)驗(yàn),以確定訓(xùn)練各種大小的模型所需的顯存(VRAM)要求。參數(shù)數(shù)量從 5 億到 140 億不等,他們比較了權(quán)重的完全微調(diào)與參數(shù)高效微調(diào)(使用 LoRA),所有訓(xùn)練運(yùn)行都在英偉達(dá) H100 上完成,因此這里的 OOM 意味著 >80GB 的 VRAM。

在表格中,你可以找到 GSM8K 數(shù)據(jù)集上訓(xùn)練的前 100 步中的峰值內(nèi)存使用情況。用于實(shí)驗(yàn)的模型是:

所有實(shí)驗(yàn)均使用 Shadeform 的 GPU 市場(chǎng)完成,因此每次實(shí)驗(yàn)只需要花費(fèi)幾美元 H100。
實(shí)驗(yàn)結(jié)果表明,內(nèi)存需求隨著模型大小和訓(xùn)練方式的不同而顯著變化。例如,全參數(shù)微調(diào)比 PEFT 需要更多的內(nèi)存。
為什么 GRPO 對(duì)內(nèi)存需求較高
這要從 GRPO 的原理說起,這是它的流程圖。

GRPO 對(duì)內(nèi)存需求較高的原因在于,其內(nèi)部涉及多個(gè)模型,并且在訓(xùn)練數(shù)據(jù)中每個(gè)查詢會(huì)產(chǎn)生多個(gè)輸出。上圖中的策略模型、參考模型和獎(jiǎng)勵(lì)模型各自都是一個(gè)需要進(jìn)行推理的 LLM。(盡管從技術(shù)上講,獎(jiǎng)勵(lì)模型可能不需要參數(shù)化,可以只是一個(gè) Python 函數(shù)或正則表達(dá)式,但不影響 GRPO 對(duì)內(nèi)存的高需求。)
為什么 8-Bit 優(yōu)化和梯度檢查點(diǎn)有助于減少內(nèi)存占用?
通常來講,訓(xùn)練一個(gè)大型語言模型需要在內(nèi)存中存儲(chǔ)三種主要類型的信息:模型參數(shù)、模型學(xué)習(xí)所需的梯度、優(yōu)化器的跟蹤數(shù)據(jù)。
對(duì)上述內(nèi)容我們可以這樣理解:如果模型的參數(shù)占用了 X 的空間,那么梯度也會(huì)占用大約相同的空間。然后,像 AdamW 這樣的優(yōu)化器需要更多的空間,因?yàn)樗鼈兙拖褚粋€(gè)記錄員,跟蹤最近的更新歷史,以便更好地決定未來的優(yōu)化。
為了減輕這種內(nèi)存負(fù)擔(dān),通常采用兩種技術(shù):
- 首先,可以使用像 AdamW 這樣的 8-bit 優(yōu)化器版本,它們能更高效地存儲(chǔ)跟蹤數(shù)據(jù),同時(shí)仍保持良好的性能 —— 類似于壓縮照片可以節(jié)省空間,同時(shí)保留大部分圖像質(zhì)量;
- 其次,使用梯度檢查點(diǎn)技術(shù),這就像在訓(xùn)練過程中拍攝快照,而不是記錄所有內(nèi)容。雖然這會(huì)使訓(xùn)練速度減慢約 20-30%,但它顯著減少了內(nèi)存使用。
結(jié)合這些技術(shù),即使對(duì) GPU 資源有限的人來說,也能夠訓(xùn)練更大的模型。
代碼示例
像 trl 這樣的庫已經(jīng)開始支持 GRPO,使得微調(diào)由 transformers 構(gòu)成的 LLM 變得非常簡(jiǎn)單。代碼也非常簡(jiǎn)潔,只需將訓(xùn)練器替換為 GRPOTrainer 并定義一些獎(jiǎng)勵(lì)即可。GRPO 的最小代碼量大約只有 99 行,如果你使用的是像 meta-llama/Llama-3.2-1B-Instruct 這樣的小型模型和像 openai/GSM8K 這樣的數(shù)據(jù)集,可以非??焖俚貑?dòng)。
trl 項(xiàng)目地址:https://github.com/huggingface/trl?ref=ghost.oxen.ai
- import
- torch
- from
- datasets
- import
- load_dataset
- Dataset
- from
- transformers
- import
- AutoTokenizer
- AutoModelForCausalLM
- from
- trl
- import
- GRPOConfig
- GRPOTrainer
- import
- re
- SYSTEM_PROMPT
- Respond in the following format:
- def
- extract_hash_answer
- text
- str
- str
- None
- if
- "####"
- not
- in
- text
- return
- None
- return
- text
- split
- "####"
- 1
- strip
- def
- get_gsm8k_questions
- split
- "train"
- Dataset
- data
- load_dataset
- 'openai/gsm8k'
- 'main'
- split
- data
- data
- map
- lambda
- 'prompt'
- 'role'
- 'system'
- 'content'
- SYSTEM_PROMPT
- },
- 'role'
- 'user'
- 'content'
- 'question'
- ],
- 'answer'
- extract_hash_answer
- 'answer'
- return
- data
- def
- extract_xml_answer
- text
- str
- str
- answer
- text
- split
- 1
- answer
- answer
- split
- ""
- 0
- return
- answer
- strip
- def
- format_reward_func
- completions
- kwargs
- list
- float
- """Reward function that checks if the completion has a specific format."""
- pattern
- r
- "^\n\n$"
- \n.*?\n
- \n.*?\n
- responses
- completion
- 0
- "content"
- for
- completion
- in
- completions
- matches
- re
- match
- pattern
- r
- for
- r
- in
- responses
- return
- 0.5
- if
- match
- else
- 0.0
- for
- match
- in
- matches
- def
- accuracy_reward_func
- prompts
- completions
- answer
- kwargs
- list
- float
- """Reward function that extracts the answer from the xml tags and compares it to the correct answer."""
- responses
- completion
- 0
- 'content'
- for
- completion
- in
- completions
- extracted_responses
- extract_xml_answer
- r
- for
- r
- in
- responses
- return
- 2.0
- if
- r
- a
- else
- 0.0
- for
- r
- a
- in
- zip
- extracted_responses
- answer
- def
- main
- dataset
- get_gsm8k_questions
- model_name
- "meta-llama/Llama-3.2-1B-Instruct"
- model
- AutoModelForCausalLM
- from_pretrained
- model_name
- torch_dtype
- torch
- bfloat16
- attn_implementation
- "flash_attention_2"
- device_map
- None
- to
- "cuda"
- tokenizer
- AutoTokenizer
- from_pretrained
- model_name
- tokenizer
- pad_token
- tokenizer
- eos_token
- training_args
- GRPOConfig
- output_dir
- "output"
- learning_rate
- 5e-6
- adam_beta1
- 0.9
- adam_beta2
- 0.99
- weight_decay
- 0.1
- warmup_ratio
- 0.1
- lr_scheduler_type
- 'cosine'
- logging_steps
- 1
- bf16
- True
- per_device_train_batch_size
- 1
- gradient_accumulation_steps
- 4
- num_generations
- 4
- max_prompt_length
- 256
- max_completion_length
- 786
- num_train_epochs
- 1
- save_steps
- 100
- save_total_limit
- 1
- max_grad_norm
- 0.1
- log_on_each_node
- False
- trainer
- GRPOTrainer
- model
- model
- processing_class
- tokenizer
- reward_funcs
- format_reward_func
- accuracy_reward_func
- ],
- args
- training_args
- train_dataset
- dataset
- trainer
- train
- if
- __name__
- "__main__"
- main
Num Generations 有什么用
Num Generations 是一個(gè)超參數(shù),它決定了我們將在訓(xùn)練數(shù)據(jù)中對(duì)每個(gè)查詢采樣多少個(gè)補(bǔ)全。然而,這會(huì)顯著增加 VRAM 的消耗。

目前有一個(gè)開放的 GitHub 問題,可能會(huì)幫助解決內(nèi)存瓶頸問題,可以參考如下鏈接
地址:https://github.com/huggingface/trl/issues/2709?ref=ghost.oxen.ai
對(duì)于 num_completions=8,16,64 (DeepSeekMath 論文使用的 64),作者表示,不用再次計(jì)算上述所有值,而是使用了 1B 參數(shù)模型進(jìn)行了測(cè)試,以顯示內(nèi)存增長(zhǎng)。不過,作者還是建議大家在內(nèi)存瓶頸得到修復(fù)之前使用 num_generations=4,也能獲得不錯(cuò)的性能。

影響 VRAM 的一些因素
要對(duì)所有影響顯存(VRAM)使用的因素進(jìn)行全面的超參數(shù)驗(yàn)證,需要進(jìn)行大量的實(shí)驗(yàn)。簡(jiǎn)單起見,這里只指出了需要注意的設(shè)置,以及實(shí)驗(yàn)中使用的具體數(shù)值。
- batch_size=1,由于 GRPO 為每個(gè)查詢生成多個(gè)響應(yīng),batch size 會(huì)迅速失控。
- gradient_accumulation_steps=4,優(yōu)化器是另一個(gè)占用大量 VRAM 的地方。此參數(shù)決定了我們將存儲(chǔ)的梯度以幫助優(yōu)化器進(jìn)行其「爬山」過程。
- num_completions=4,DeepSeekMath 論文中使用了 64。這完全超出了有些人的計(jì)算預(yù)算。
- max_prompt_length=256,如果你想訓(xùn)練模型擁有更大上下文的推理能力,將不得不增加 VRAM。GSM8K 的提示相對(duì)較小,適合此測(cè)試。
- max_completion_length=786,同樣,由于計(jì)算注意力的內(nèi)存有限,推理鏈在這里受到限制。上下文或生成的 token 越多,需要的內(nèi)存就越大。
- LoRA target_modules=["q_proj", "k_proj", "o_proj", "up_proj", "down_proj"] 在這方面可以嘗試幾種不同的迭代。target_modules="all-linear" 是一種流行的方式,可以從你的 LoRA 中擠出最多的性能(就準(zhǔn)確性而言)。
對(duì) VRAM 使用的粗略估算
如果你正在使用 FP16 精度進(jìn)行訓(xùn)練,以下是一些簡(jiǎn)單的估算方法,可以幫助你了解內(nèi)存主要用在了哪些地方:
- 模型參數(shù):每個(gè)參數(shù)占用 2 字節(jié)。
- 參考模型參數(shù):每個(gè)參數(shù)占用 2 字節(jié)。
- 梯度:每個(gè)參數(shù)占用 2 字節(jié)。
- 優(yōu)化器狀態(tài):每個(gè)參數(shù)占用 8 字節(jié)。
- 8 位優(yōu)化器:每個(gè)參數(shù)占用 4 字節(jié)。
- PEFT:有助于減少梯度的顯存占用。
最后是關(guān)于準(zhǔn)確率的。作者完成了一個(gè) 10 億參數(shù)的 Llama 3.2 模型的完整訓(xùn)練。在應(yīng)用 GRPO 之前,該模型在保留測(cè)試集上達(dá)到了約 19% 的準(zhǔn)確率,而在經(jīng)過一個(gè)訓(xùn)練周期后,模型的準(zhǔn)確率飆升至約 40.5%。雖然這離 SOTA 水平還差得很遠(yuǎn),但這展示了 GRPO 的強(qiáng)大潛力。
熱門跟貼