
機(jī)器之心報(bào)道
編輯:杜偉
最近,DeepSeek-R1 和 OpenAI o1/03 等推理大模型在后訓(xùn)練階段探索了長(zhǎng)度擴(kuò)展(length scaling),通過強(qiáng)化學(xué)習(xí)(比如 PPO、GPRO)訓(xùn)練模型生成很長(zhǎng)的推理鏈(CoT),并在奧數(shù)等高難度推理任務(wù)上取得了顯著的效果提升。
受此啟發(fā),研究人員開始探索預(yù)訓(xùn)練階段的長(zhǎng)度擴(kuò)展,已有方法包括在序列中插入文本、插入潛在向量(如 Coconut)、復(fù)用中間層隱藏狀態(tài)(如 CoTFormer)以及將中間隱藏狀態(tài)映射為概念(如 COCOMix)。不過,這些方法普遍存在問題,比如需要更大的 KV 緩存導(dǎo)致推理慢 / 占內(nèi)存多。
本文中,來自 ByteDance Seed 團(tuán)隊(duì)的研究者提出了更簡(jiǎn)單的方法:直接重復(fù)輸入 tokens(1/2/3/4 次),不做中間層處理。他們觀察到了訓(xùn)練損失和模型性能隨重復(fù)倍數(shù)擴(kuò)展的趨勢(shì),如下圖 1a 和 1b 所示。但是,直接重復(fù) tokens 也帶來了新問題,包括 KV 緩存規(guī)模線性增加,內(nèi)存壓力大;預(yù)填充時(shí)間超線性增加;解碼延遲變長(zhǎng)。這些都是實(shí)現(xiàn)預(yù)訓(xùn)練長(zhǎng)度擴(kuò)展需要重點(diǎn)解決的挑戰(zhàn)。

- 論文標(biāo)題:Efficient Pretraining Length Scaling
- arXiv 地址:https://arxiv.org/pdf/2504.14992
研究者提出了一種推理友好的新穎長(zhǎng)度擴(kuò)展方法,核心是 PHD-Transformer(Parallel Hidden Decoding Transformer),它保持了與原始 transformer 相同的 KV 緩存大小,同時(shí)實(shí)現(xiàn)有效的長(zhǎng)度擴(kuò)展。PHD-Transformer 通過創(chuàng)新的 KV 緩存管理策略實(shí)現(xiàn)了這些能力。
具體來講,研究者將第一個(gè) token 表示原始 token,將重復(fù)的 token 表示為解碼 token。同時(shí)僅保留從原始 token 生成的 KV 緩存來用于長(zhǎng)距離依賴建模,并在隱藏解碼 token 用于下一個(gè) token 預(yù)測(cè)之后丟棄它們的 KV 緩存。因此,PHD-Transformer 提供了與原始 transformer 相同的 KV 緩存,同時(shí)相較于簡(jiǎn)單的 token 重復(fù)實(shí)現(xiàn)了顯著的推理加速(如圖 1d 所示)。

研究者還注意到,在 PHD-SWA 中,隱藏解碼 token 的 KV 緩存表現(xiàn)出了順序依賴關(guān)系,這導(dǎo)致預(yù)填充時(shí)間呈線性增長(zhǎng)。為了解決這個(gè)問題,研究者提出了逐塊滑動(dòng)窗口注意力 —— PHD-CSWA,從而限制了每個(gè)塊內(nèi)的順序依賴關(guān)系。
因此,得益于只有最后一個(gè)塊的預(yù)填充時(shí)間呈線性增長(zhǎng),PHD-CSWA 顯著縮短了預(yù)填充時(shí)間(如圖 1c 所示)。

方法概覽

研究者在推理過程中實(shí)現(xiàn)了與原始 Transformer 相同的 KV 緩存大小和內(nèi)存訪問模式。雖然需要 K 次 FLOP,但這些計(jì)算可以并行處理,從而在內(nèi)存受限的推理場(chǎng)景中最大限度地降低延遲開銷。該架構(gòu)的核心優(yōu)勢(shì)在于原始 token 和隱藏解碼 token 之間的解耦。在預(yù)填充期間,只有原始 token 需要計(jì)算。
這種設(shè)計(jì)確保預(yù)填充時(shí)間與原始 Transformer 相同,并且無論擴(kuò)展因子 K 如何變化,預(yù)填充時(shí)間都保持不變。而對(duì)于損失計(jì)算,研究者僅使用 token 的最終副本進(jìn)行下一個(gè) token 的預(yù)測(cè)??傊褂?token 的第一個(gè)副本進(jìn)行 KV 緩存生成,使用 token 的最后一個(gè)副本進(jìn)行下一個(gè) token 的預(yù)測(cè)。

內(nèi)核設(shè)計(jì)


PHD-SWA 和 PHD-CSWA
與簡(jiǎn)單的 token 重復(fù)相比,PHD-Transformer 在保持原始 KV 緩存大小的同時(shí)實(shí)現(xiàn)了長(zhǎng)度擴(kuò)展。然而通過經(jīng)驗(yàn)觀察到,為隱藏解碼 token 保留一些 KV 緩存可以帶來顯著的性能提升。因此,為了在保持效率的同時(shí)獲得這些優(yōu)勢(shì),研究者引入了 PHD-SWA,將滑動(dòng)窗口注意力限制在 W 個(gè)先前的隱藏解碼 token 上。

雖然 PHD-SWA 滑動(dòng)窗口方法提升了模型性能,但由于隱藏解碼 token 的 KV 緩存中存在順序依賴關(guān)系,它會(huì)產(chǎn)生 K 倍的預(yù)填充開銷。為了解決這個(gè)問題,研究者引入了 PHD-CSWA,它可以在獨(dú)立的塊內(nèi)處理注意力。
如下圖 4 所示,PHD-CSWA 將滑動(dòng)窗口注意力限制在單個(gè)塊內(nèi)運(yùn)行。這種架構(gòu)創(chuàng)新將額外的預(yù)填充開銷減少到最終塊內(nèi)的 K 次重復(fù),而不是整個(gè)序列重復(fù),這使得額外的計(jì)算成本幾乎可以忽略不計(jì),同時(shí)保留了局部注意力模式的優(yōu)勢(shì)。

實(shí)驗(yàn)結(jié)果
在實(shí)驗(yàn)中,研究者使用 OLMo2 作為代碼庫,并在 ARC、HellaSwag、PIQA、Winogrande、MMLU 和 CommonsenseQA 等公開基準(zhǔn)測(cè)試集上進(jìn)行了評(píng)估。
訓(xùn)練細(xì)節(jié):研究者使用 1.2B 參數(shù)規(guī)模的模型,它是一個(gè) 16 層的密集模型。每個(gè) token 的隱藏層維數(shù)設(shè)置為 2048,F(xiàn)FN 層的隱藏層大小設(shè)置為 16384。同時(shí)使用組查詢注意力 (Group-Query Attention,GQA),它包含 32 個(gè)查詢頭和 8 個(gè)鍵 / 值頭,每個(gè)頭的隱藏層維數(shù)設(shè)置為 64。研究者使用 500B 個(gè) token 訓(xùn)練該模型。
對(duì)于本文提出的 PHD 系列設(shè)置,研究者預(yù)訓(xùn)練了以下兩種 PHD-CSWA 變體:
- PHD-CSWA-2-16-32,其中訓(xùn)練 token 重復(fù)兩次。保留一個(gè)包含 16 個(gè) token 的局部窗口,并將塊大小設(shè)置為 32 個(gè) token。
- PHD-CSWA-3-16-32,其中訓(xùn)練 token 重復(fù)三次。局部窗口大小和塊大小與 PHD-CSWA-2-16-32 的設(shè)置相同。
PHD-CSWA 在各個(gè)基準(zhǔn)測(cè)試中均實(shí)現(xiàn)了持續(xù)的性能提升。下圖 5 中展示了訓(xùn)練曲線,下表 1 中展示了主要結(jié)果。本文提出的 PHD-CSWA-2-16-32 在這些基準(zhǔn)測(cè)試中平均實(shí)現(xiàn)了 1.5% 的準(zhǔn)確率提升,訓(xùn)練損失降低了 0.025;而 PHD-CSWA-3-16-32 在這些基準(zhǔn)測(cè)試中平均實(shí)現(xiàn)了 2.0% 的準(zhǔn)確率提升,訓(xùn)練損失降低了 0.034。


研究者還分析了 PHD 和 PHD-SWA 的擴(kuò)展性能,以分析擴(kuò)展解碼計(jì)算的性能。 訓(xùn)練細(xì)節(jié):使用相同的 550M 模型配置,將窗口大小 W 設(shè)置為 16,并在 {2, 3, 5} 范圍內(nèi)改變擴(kuò)展因子 K。對(duì)于局部窗口大小,研究者在所有實(shí)驗(yàn)中都將窗口大小設(shè)置為 16。
PHD-SWA 的性能在增加擴(kuò)展因子時(shí)有效擴(kuò)展。如下圖 8 所示,使用固定窗口大小時(shí),損失曲線和下游性能會(huì)隨著 token 重復(fù)次數(shù)而有效擴(kuò)展。通過將擴(kuò)展因子設(shè)置為 5,可以實(shí)現(xiàn)接近 0.06 的損失降低,同時(shí)顯著提升下游性能。
下表 2 中的定量結(jié)果表明,當(dāng)擴(kuò)展至 K = 5 時(shí),所有基準(zhǔn)測(cè)試的平均準(zhǔn)確率提高了 1.8%,這證實(shí)了本文的方法在更激進(jìn)的擴(kuò)展方面仍然有效。


更多實(shí)驗(yàn)結(jié)果請(qǐng)參閱原論文。
熱門跟貼