機器之心報道
編輯:蛋醬、杜偉
Attention 還在卷自己。
當(dāng)上下文包含大量 Token 時,如何在忽略干擾因素的同時關(guān)注到相關(guān)部分,是一個至關(guān)重要的問題。然而,大量研究表明,標(biāo)準(zhǔn)注意力在這種情況下可能會出現(xiàn)性能不佳的問題。
標(biāo)準(zhǔn)多頭注意力的工作原理是使用點積比較當(dāng)前查詢向量與上下文 Token 對應(yīng)的鍵向量的相似性。與查詢相似的關(guān)鍵字會獲得更高的注意力權(quán)重,隨后其值向量會主導(dǎo)輸出向量。
例如,與「Alice」Token 相對應(yīng)的查詢向量能夠定位上下文中所有提及「Alice」的內(nèi)容。然而,每個注意力權(quán)重只取決于單個關(guān)鍵字和查詢向量(除了歸一化為 1)。
對單個 token 向量相似性的依賴給注意力機制帶來了根本性的限制。在許多情況下,上下文的相關(guān)部分無法通過單個 token 來識別。例如,查找一個同時提到「Alice」和「rabbit」的句子需要查詢向量對這兩個 token 進(jìn)行編碼。用一個注意頭查找「Alice」,再用另一個注意頭查找「rabbit」,可以分別找到這兩個詞,但不足以確定這兩個詞在哪里被同時提及雖然可以通過 Transformer 的層將多個 token 編碼成一個向量,但這需要增加維度,而且模型需要將大量容量用于這項任務(wù)。
在本文中,研究者提出了一種超越「單個 token」瓶頸的新型注意力機制 ——Multi-Token 注意力(MTA),其高層次目標(biāo)是利用多個向量對的相似性來確定注意力必須集中在哪里。
而研究者僅通過對現(xiàn)有注意力機制進(jìn)行簡單的修改去實現(xiàn)這一目標(biāo)。他們設(shè)計了對注意力權(quán)重的卷積運算,該運算在三個維度上運行:鍵、查詢和注意力頭。這就允許其注意力權(quán)重以相鄰鍵、之前的查詢和其他頭為條件。
直觀地說,在上述例子中,MTA 可以先分別查找「Alice」和「rabbit」的提及,然后將這些注意力組合在一起,只關(guān)注兩者都存在的地方。

- 論文:Multi-Token Attention
- 論文鏈接:https://arxiv.org/abs/2504.00927
具體來說,這項研究的亮點在于:
- 研究者首先用一個有趣的玩具任務(wù)進(jìn)行實驗,該任務(wù)揭示了標(biāo)準(zhǔn)注意力的缺陷,并證明 MTA 可以輕松解決這一問題;
- 接下來,研究者通過在標(biāo)準(zhǔn)語言建模任務(wù)中對 1050 億個詞庫的 880M 個參數(shù)模型進(jìn)行預(yù)訓(xùn)練,對本文的方法進(jìn)行了大規(guī)模測試;
- 研究者發(fā)現(xiàn) MTA 在驗證復(fù)雜度和標(biāo)準(zhǔn)基準(zhǔn)任務(wù)方面都有所改進(jìn),而參數(shù)數(shù)量只增加了 0.001%;
- 此外,研究者還在長語境任務(wù)(如 Needle-in-the-Haystack 和 BabiLong)上評估了所生成的模型,結(jié)果發(fā)現(xiàn) MTA 的表現(xiàn)明顯優(yōu)于基線。
方法概覽
如圖 1(右圖)所示,本文提出的「Multi-Token 注意力」由建立在多頭注意力基礎(chǔ)上的三個重要部分組成:鍵 - 查詢卷積、頭混合卷積和帶深度縮放的組歸一化。
研究者提出了鍵 - 查詢卷積,以在頭部內(nèi)組合多個鍵和查詢,并提出了頭卷積,在頭之間共享知識并放大重要信息。最后,研究者應(yīng)用具有深度縮放功能的組歸一化來抵消殘差流,改善梯度流。

鍵 - 查詢卷積(key-query convolution)
對于 pre-softmax 卷積,MTA 在注意力 logit 上進(jìn)行了一個卷積操作,并結(jié)合來自多個查詢和鍵 token 的信息:

鍵和查詢的長度維數(shù)中采用了卷積,同時 batch 和頭維數(shù)保持獨立。更確切地說,從查詢 q_i 到鍵 k_j 的注意力權(quán)重 a_ij 計算如下:

對于鍵,研究者使用指示函數(shù) 1_i≥j?j′將未來鍵歸零。但是,這樣的掩碼太復(fù)雜,無法實現(xiàn)(必須修改卷積 CUDA 內(nèi)核),因此本文提出了一個更簡單的版本,將已有的因果掩碼應(yīng)用了兩次:

對于 post-softmax 卷積,研究者同樣在注意力權(quán)重的頂部進(jìn)行卷積操作:

這使得注意力權(quán)重之間的交互累加而不是相乘。研究者試驗了兩個版本,但默認(rèn)情況下使用 pre-softmax 版本。每個注意力頭都有單獨的 θ 參數(shù),所以它們可以執(zhí)行不同的卷積操作。選擇的內(nèi)核維數(shù)決定了如何將離得遠(yuǎn)的 token 組合在一起。
頭混合卷積(head mixing convolution)
鍵 - 查詢卷積允許從不同的時間步中混合注意力權(quán)重,而研究者進(jìn)一步提出在頭組中使用頭卷積,因此可以將不同頭的注意力權(quán)重組合起來。
具體地,對于大小為 c_h 的頭卷積內(nèi)核,所有頭被分為 M/c_h 個組。在每個組中,研究者使用了不重疊的卷積操作。這樣一來,MTA 不僅允許在每個頭內(nèi)部的多個查詢和鍵向量上調(diào)整注意力權(quán)重,還可以跨頭共享注意力信息。
舉例而言,考慮將所有頭分為兩個組,使內(nèi)核大小為「c_h = 2」。當(dāng)使用上標(biāo)來表示頭指數(shù)時,則 A^1 和 A^2 是來自兩個不同頭的注意力權(quán)重。這時,新的注意力權(quán)重如下:

其中 w_11、w_12、w_21 和 w_22 是內(nèi)核權(quán)重。這里 softmax 之后出現(xiàn)混合,但可以在 softmax 之前混合 logit。

將一切組合起來(putting everything together)
在前文中,研究者引入兩種不同的方式來混合注意力權(quán)重,一是跨鍵 - 查詢時間步,二是跨不同頭。這兩種方式都可以在單個 MTA 模塊中實現(xiàn)。每種方式都有 pre - 和 post-softmax 版本,因此有多種方法將它們組合在一起。如果都采用 pre-softmax 來混合,則可以通過單個 3 維卷積操作來實現(xiàn),如下圖 2 所示。

實驗結(jié)果
研究者在一系列標(biāo)準(zhǔn)和長距離(long-range)依賴任務(wù)上對 MTA 架構(gòu)進(jìn)行了實驗,并與基線進(jìn)行了比較,從「toy」任務(wù)開始。他們使用了鍵 - 查詢卷積 pre-softmax 和頭混合 post-softmax,另有說明除外。
簡單的 toy 任務(wù)
研究者首先測試了 toy 任務(wù),以驗證本文方法相較于標(biāo)準(zhǔn)多頭注意力的有效性。此任務(wù)中為模型提供了一個塊序列,其中每個塊由 N 個隨機字母組成。相比之下,MTA 先是找到了每個問題字母的位置,然后使用卷積操作來增加所有 L 字母一起被發(fā)現(xiàn)的位置的注意力。
結(jié)果如下表 1 所示,如預(yù)期一樣,具有標(biāo)準(zhǔn)多頭注意力的 transformer 解決這項任務(wù)時,即使問題中只有「L = 2」字母,通常也無法找到目標(biāo)塊。相比之下,MTA 以接近零誤差的成功率解決了所有版本的任務(wù)。

大型語言建模
對于語言建模實驗,研究者對 880M 參數(shù)的模型進(jìn)行了預(yù)訓(xùn)練,并比較了 Transformer、DIFF Transformer 和 Transformer with MTA。對于每個模型,他們進(jìn)行了兩次訓(xùn)練,并在下表 2 中提供了平均驗證困惑度。
結(jié)果顯示,經(jīng)過 MTA 訓(xùn)練的模型,在所有驗證數(shù)據(jù)集上均實現(xiàn)了性能提升,即使只在四分之一的層中應(yīng)用鍵 - 查詢卷積,并且要比 DIFF Transformer 的可學(xué)習(xí)參數(shù)更少。此外,使用層 scaling 的組歸一化是一個重要組件,可以為 DIFF Transformer 和 MTA 架構(gòu)提供更優(yōu)越的性能。

接著,研究者在以上相同的六個數(shù)據(jù)集上對模型進(jìn)行了另外 10.5B token 的微調(diào),并將上下文長度從 2048 增加到了 4096。同時將 RoPE 的 θ 值增加到了 50 萬,將權(quán)重衰減變成 0,并將預(yù)熱步驟降為 50,其他參數(shù)與預(yù)訓(xùn)練階段保持一致。結(jié)果表明,使用 MTA 生成的 Transformer 模型在困惑度評估中同樣優(yōu)于新的基線。
在 zero-shot 設(shè)置下,研究者進(jìn)一步評估了模型在一系列流行基準(zhǔn)上的表現(xiàn),結(jié)果如下表 3 所示。經(jīng)過 MTA 訓(xùn)練的模型在大多數(shù)基準(zhǔn)上優(yōu)于基線,并取得了更高的平均分,盡管這些并不是長上下文任務(wù)。

長距離依賴任務(wù) Long-range dependency tasks
此前的研究表明,Transformer 很難找到相關(guān)信息,尤其是在長上下文中。
為了在這種情況下測試 MTA,研究者在三個任務(wù)中對訓(xùn)練有素的模型進(jìn)行了評估: LAMBADA、NeedleIn-A-Haystack 和 BabiLong。所有這些任務(wù)都要求模型幾乎要密切關(guān)注埋藏在上下文中的長距離 tokens。
LAMBADA。研究者觀察到使用 MTA 訓(xùn)練的模型在正確猜測下一個單詞方面更勝一籌(如表 4),明顯優(yōu)于基線 Transformer 模型。

如表 5 所示,使用 MTA 訓(xùn)練的模型在所有「針數(shù)」和不同上下文長度的撈針能力都有顯著提高。

BabiLong。研究者將重點放在了 QA1-5 任務(wù)上,在這些任務(wù)中,正確的回答需要不同數(shù)量的事實或論據(jù)關(guān)系。輸入和目標(biāo)輸出樣本如表 7 所示。

圖 4(左)展示了平均準(zhǔn)確率,附圖 5 展示了每個任務(wù)的準(zhǔn)確率。與其他模型相比,MTA 模型表現(xiàn)良好,尤其是當(dāng)輸入中有較多干擾文本(4K token)時。


更多實驗結(jié)果請查看原論文。
熱門跟貼