機(jī)器之心報(bào)道

編輯:陳陳、杜偉

  • 大語(yǔ)言模型的推理能力,不再是 AR(自回歸)的專屬。擴(kuò)散模型現(xiàn)在也能「動(dòng)腦子」,新框架 d1 讓它們學(xué)會(huì)了解數(shù)學(xué)、懂邏輯、會(huì)思考。

當(dāng)前,強(qiáng)化學(xué)習(xí)(RL)方法在最近模型的推理任務(wù)上取得了顯著的改進(jìn),比如 DeepSeek-R1、Kimi K1.5,顯示了將 RL 直接用于基礎(chǔ)模型可以取得媲美 OpenAI o1 的性能。

不過(guò),基于 RL 的后訓(xùn)練進(jìn)展主要受限于自回歸的大語(yǔ)言模型(LLM),它們通過(guò)從左到右的序列推理來(lái)運(yùn)行。

與此同時(shí),離散擴(kuò)散大語(yǔ)言模型(dLLM)成為有潛力的語(yǔ)言建模的非自回歸替代。不像以因果方式逐 token 生成文本的自回歸模型那樣,dLLM 通過(guò)迭代去噪過(guò)程生成文本,在多步驟操作中優(yōu)化序列的同時(shí)并通過(guò)雙向注意力利用過(guò)去和未來(lái)的上下文。其中,LLaDA 等開放的掩碼 dLLM 實(shí)現(xiàn)了媲美同尺寸自回歸模型的性能,而 Mercury 等閉源 dLLM 進(jìn)一步展現(xiàn)了出色的推理延遲。

然而,頂級(jí)的開源 dLLM 并沒有使用 RL 后訓(xùn)練,使得這一有潛力的研究方向還有很大的挖掘空間。這一范式轉(zhuǎn)變引出了重要的問題:RL 后訓(xùn)練如何在非自回歸上下文中高效地實(shí)現(xiàn)?

RL 算法適應(yīng)掩碼 dLLM 面臨一些獨(dú)特的挑戰(zhàn),原因在于自回歸模型采用的已有方法(如 PPO、GRPO)通過(guò)計(jì)算生成序列的對(duì)數(shù)概率來(lái)估計(jì)和優(yōu)化策略分布,導(dǎo)致無(wú)法直接應(yīng)用于 dLLM。雖然這種計(jì)算在自回歸模型中通過(guò)序列因式分解很容易實(shí)現(xiàn),但 dLLM 由于它們的迭代、非序列生成過(guò)程而缺乏這種自然分解。

為了解決這些問題,來(lái)自 UCLA 和 Meta AI 的研究者提出了一個(gè)兩階段后訓(xùn)練框架 d1,從而可以在掩碼 dLLM 中進(jìn)行推理。在第一階段,模型在高質(zhì)量推理軌跡中進(jìn)行監(jiān)督微調(diào);在第二即 RL 階段,研究者引入了用于掩碼 dLLM 的新穎策略梯度方法 diffu-GRPO,它利用提出的高效一步(one-step)對(duì)數(shù)概率估計(jì)在 GRPO 的基礎(chǔ)上創(chuàng)建。

研究者表示,他們的估計(jì)器利用了隨機(jī)提示詞掩碼,作為策略優(yōu)化的一種正則化,使得可以擴(kuò)展 per batch 的梯度更新數(shù)量并減少 RL 訓(xùn)練所需的在線生成數(shù)量。這將極大地降低計(jì)算時(shí)間。

打開網(wǎng)易新聞 查看精彩圖片

  • 論文標(biāo)題:d1: Scaling Reasoning in Diffusion Large Language Models via Reinforcement Learning
  • 論文地址:https://arxiv.org/pdf/2504.12216
  • 項(xiàng)目主頁(yè):https://dllm-reasoning.github.io/
  • GitHub 地址:https://github.com/dllm-reasoning/d1

在實(shí)驗(yàn)部分,研究者使用 LLaDA-8B-Instruct 作為基礎(chǔ)模型實(shí)例化 d1。他們將 d1-LLaDA 的性能與基礎(chǔ) LLaDA 模型以及僅使用 SFT 和僅使用 diffu-GRPO 訓(xùn)練的 LLaDA 模型進(jìn)行比較。結(jié)果表明,d1 在四個(gè)數(shù)學(xué)和邏輯推理基準(zhǔn)測(cè)試中始終優(yōu)于基礎(chǔ)模型,如下圖 1 所示。d1-LLaDA 同樣優(yōu)于僅使用 SFT 方法和僅使用 diffu-GRPO 方法的模型。

打開網(wǎng)易新聞 查看精彩圖片

方法概覽

d1 是一個(gè)兩階段框架,通過(guò)依次結(jié)合監(jiān)督微調(diào)(SFT)和在線強(qiáng)化學(xué)習(xí)(RL)來(lái)增強(qiáng)預(yù)訓(xùn)練掩碼 dLLMs 的推理性能。

其中,在線強(qiáng)化學(xué)習(xí)(特別是 GRPO 算法)已被證明能有效提升離線訓(xùn)練語(yǔ)言模型的性能。然而,GRPO 的學(xué)習(xí)策略并不能直接泛化到 dLLMs。

GRPO 的目標(biāo)函數(shù)(如公式 3 所示)需要同時(shí)計(jì)算當(dāng)前策略 π_θ 和舊策略 π_θold 在以下兩個(gè)層面的(對(duì)數(shù))似然比:

  1. token 層面(用于優(yōu)勢(shì)權(quán)重計(jì)算);
  2. 序列層面(用于反向 KL 散度項(xiàng))。

核心問題在于:研究者需要高效計(jì)算 dLLMs 生成內(nèi)容的逐 token 對(duì)數(shù)概率和序列對(duì)數(shù)概率。

自回歸(AR)模型,如 Transformer,直接對(duì)每個(gè) token 的對(duì)數(shù)概率進(jìn)行建模,并且可以通過(guò)鏈?zhǔn)椒▌t使用一次前向傳遞輕松計(jì)算出序列級(jí)別的對(duì)數(shù)概率

同樣,KL 項(xiàng)可以分解為。

打開網(wǎng)易新聞 查看精彩圖片

與 AR 模型不同,dLLMs 不遵循序列對(duì)數(shù)概率的順序分解。同時(shí),每個(gè) token 的對(duì)數(shù)概率計(jì)算成本也很高,因?yàn)榻獯a過(guò)程中需要多次調(diào)用掩碼預(yù)測(cè)器 f_θ?;诖?,該研究提出了一個(gè)高效的對(duì)數(shù)概率估計(jì)器。

對(duì)于序列對(duì)數(shù)概率,該研究使用均場(chǎng)近似方法,將其分解為獨(dú)立的每個(gè) token 對(duì)數(shù)概率的乘積。

對(duì)于每個(gè) token 的對(duì)數(shù)概率,該研究引入了一種估計(jì)方法,該方法僅調(diào)用一次 f_θ。

基于新引入的對(duì)數(shù)概率估計(jì)器,該研究將 GRPO 擴(kuò)展到掩碼 dLLMs,推導(dǎo)出 diffu-GRPO 的損失函數(shù)。

打開網(wǎng)易新聞 查看精彩圖片

算法如下圖所示。

打開網(wǎng)易新聞 查看精彩圖片

實(shí)驗(yàn)結(jié)果

表 1 報(bào)告了基線模型 LLaDA-8B-Instruct 與采用不同后訓(xùn)練優(yōu)化方案的模型,在四項(xiàng)任務(wù)上的零樣本性能對(duì)比。

打開網(wǎng)易新聞 查看精彩圖片

圖 3 繪制了有效 token 的平均數(shù)量:

打開網(wǎng)易新聞 查看精彩圖片

基于實(shí)驗(yàn),該研究得出以下主要發(fā)現(xiàn):

diffu-GRPO 在所有 12 種設(shè)置中都一致優(yōu)于基礎(chǔ)的 LLaDA 和 SFT(監(jiān)督式微調(diào))。diffu-GRPO 和 SFT 都相較于 LLaDA-8B-Instruct 基線有所提升,但 diffu-GRPO 顯示出更持續(xù)且幅度更大的增益。具體來(lái)說(shuō),diffu-GRPO 在所有 12 種設(shè)置中都優(yōu)于 LLaDA-8B-Instruct 和 SFT,而 SFT 僅在其中的 7 種設(shè)置中優(yōu)于 LLaDA-8B-Instruct,這表明diffu-GRPO 相比于單獨(dú)的 SFT 實(shí)現(xiàn)了更強(qiáng)的整體性能提升。

LLaDA+diffu-GRPO 在所有設(shè)置中都優(yōu)于基礎(chǔ)的 LLaDA-8B-Instruct 模型,而 d1-LLaDA 在每種情況下都超過(guò)了 LLaDA+SFT。這表明,無(wú)論初始化是來(lái)自預(yù)訓(xùn)練模型還是經(jīng)過(guò) SFT 調(diào)整的檢查點(diǎn),diffu-GRPO 都能提供可靠的性能提升。

d1 訓(xùn)練方案實(shí)現(xiàn)了最顯著的性能提升。通過(guò)先進(jìn)行監(jiān)督微調(diào)(SFT)、再結(jié)合 diffu-GRPO 訓(xùn)練所形成的 d1-LLaDA 模型,產(chǎn)生了超越單一方法的疊加增益。這種組合式方法在 12 個(gè)實(shí)驗(yàn)設(shè)置中有 11 項(xiàng)優(yōu)于純 diffu-GRPO 方案,表明兩個(gè)訓(xùn)練階段存在協(xié)同效應(yīng)。

定性結(jié)果表明,在 SFT 和 d1-LLaDA 生成中出現(xiàn)了頓悟時(shí)刻。盡管與 LLaDA-8B-Instruct 相比,生成序列長(zhǎng)度為 128 和 256 的性能隨著 SFT、diffu-GRPO 和 d1 有所提高,但從質(zhì)的方面看,在生成的推理軌跡中并未觀察到顯著差異。然而當(dāng)序列長(zhǎng)度達(dá)到 512 時(shí),該研究開始觀察到 SFT 和 d1-LLaDA 模型展現(xiàn)出兩種關(guān)鍵能力:自我修正機(jī)制和回溯行為。