AIxiv專欄是機器之心發(fā)布學(xué)術(shù)、技術(shù)內(nèi)容的欄目。過去數(shù)年,機器之心AIxiv專欄接收報道了2000多篇內(nèi)容,覆蓋全球各大高校與企業(yè)的頂級實驗室,有效促進了學(xué)術(shù)交流與傳播。如果您有優(yōu)秀的工作想要分享,歡迎投稿或者聯(lián)系報道。投稿郵箱:liyazhou@jiqizhixin.com;zhaoyunfeng@jiqizhixin.com
王家豪,香港大學(xué)計算機系二年級博士,導(dǎo)師為羅平教授,研究方向為神經(jīng)網(wǎng)絡(luò)輕量化。碩士畢業(yè)于清華大學(xué)自動化系,已在 NeurIPS、CVPR 等頂級會議上發(fā)表了數(shù)篇論文。
太長不看版:香港大學(xué)聯(lián)合上海人工智能實驗室,華為諾亞方舟實驗室提出高效擴散模型 LiT:探索了擴散模型中極簡線性注意力的架構(gòu)設(shè)計和訓(xùn)練策略。LiT-0.6B 可以在斷網(wǎng)狀態(tài),離線部署在 Windows 筆記本電腦上,遵循用戶指令快速生成 1K 分辨率逼真圖片。

圖 1:LiT 在 Windows 筆記本電腦的離線端側(cè)部署:LiT 可以在端側(cè),斷網(wǎng)狀態(tài),以完全離線的方式遵循用戶指令,快速生成 1K 分辨率圖片
- 論文名稱:LiT: Delving into a Simplified Linear Diffusion Transformer for Image Generation
- 論文地址:https://arxiv.org/pdf/2501.12976v1
- 項目主頁:https://techmonsterwang.github.io/LiT/
為了提高擴散模型的計算效率,一些工作使用 Sub-quadratic 計算復(fù)雜度的模塊來替代二次計算復(fù)雜度的自注意力(Self-attention)機制。這其中,線性注意力的主要特點是:1) 簡潔;2) 并行化程度高。這對于大型語言模型、擴散模型這樣的大尺寸、大計算的模型而言很重要。
就在幾天前,MiniMax 團隊著名的《MiniMax-01: Scaling Foundation Models with Lightning Attention》已經(jīng)在大型語言模型中驗證了線性模型的有效性。而在擴散模型中,關(guān)于「線性注意力要怎么樣設(shè)計,如何訓(xùn)練好基于純線性注意力的擴散模型」的討論仍然不多。
本文針對這個問題,該團隊提出了幾條「拿來即用」的解決方案,向社區(qū)讀者報告了可以如何設(shè)計和訓(xùn)練你的線性擴散 Transformer(linear diffusion Transformers)。列舉如下:
- 使用極簡線性注意力機制足夠擴散模型完成圖像生成。除此之外,線性注意力還有一個「免費午餐」,即:使用更少的頭(head),可以在增加理論 GMACs 的同時 (給模型更多計算),不增加實際的 GPU 延遲。
- 線性擴散 Transformer 強烈建議從一個預(yù)訓(xùn)練好的 Diffusion Transformer 里做權(quán)重繼承。但是,繼承權(quán)重的時候,不要繼承自注意力中的任何權(quán)重(Query, Key, Value, Output 的投影權(quán)重)。
- 可以使用知識蒸餾(Knowledge Distillation)加速訓(xùn)練。但是,在設(shè)計 KD 策略時,我們強烈建議不但蒸餾噪聲預(yù)測結(jié)果,同樣也蒸餾方差預(yù)測結(jié)果 (這一項權(quán)重更小)
LiT 將上述方案匯總成了 5 條指導(dǎo)原則,方便社區(qū)讀者拿來即用。
在標準 ImageNet 基準上,LiT 只使用 DiT 20% 和 23% 的訓(xùn)練迭代數(shù),即可實現(xiàn)相當 FID 結(jié)果。LiT 同樣比肩基于 Mamba 和門控線性注意力的擴散模型。
在文生圖任務(wù)中,LiT-0.6B 可以在斷網(wǎng)狀態(tài),離線部署在 Windows 筆記本電腦上,遵循用戶指令快速生成 1K 分辨率逼真圖片,助力 AIPC 時代降臨。
目錄
1 LiT 研究背景
2 線性注意力計算范式
3 線性擴散 Transformer 的架構(gòu)設(shè)計
4 線性擴散 Transformer 的訓(xùn)練方法
5 圖像生成實驗驗證
6 文生圖實驗驗證
7 離線端側(cè)部署
1 LiT 研究背景
Diffusion Transformer 正在助力文生圖應(yīng)用的商業(yè)化,展示出了極強的商業(yè)價值和潛力。但是,自注意力的二次計算復(fù)雜度也成為了 Diffusion Transformer 的一個老大難問題。因為這對于高分辨率的場景,或者端側(cè)設(shè)備的部署都不算友好。
常見的 Sub-quadratic 計算復(fù)雜度的模塊有 Mamba 的狀態(tài)空間模型(SSM)、門控線性注意力(GLA)、線性注意力等等。目前也有相關(guān)的工作將其用在基于類別的(class-conditional)圖像生成領(lǐng)域 (非文生圖),比如使用了 Mamba 的 DiM、使用了 GLA 的 DiG 。但是,雖然這些工作確實實現(xiàn)了 Sub-quadratic 的計算復(fù)雜度,但是,這些做法也存在明顯的不足:
- 其一,SSM 和 GLA 模塊都依賴遞歸的狀態(tài) (State) 變量,需要序列化迭代計算,對于并行化并不友好。
- 其二,SSM 和 GLA 模塊的計算圖相對于 線性注意力 而言更加復(fù)雜,而且會引入一些算數(shù)強度 (arithmetic-intensity) 比較低的操作,比如逐元素乘法。
而線性注意力相比前兩者,如下圖 2 所示,不但設(shè)計簡單,而且很容易實現(xiàn)并行化。這樣的特點使得線性注意力對于高分辨率極其友好。比如對于 2048px 分辨率圖片,線性注意力比自注意力快約 9 倍,對于 DiT-S/2 生成所需要的 GPU 內(nèi)存也可以從約 14GB 降低到 4GB。因此,訓(xùn)練出一個性能優(yōu)異的基于線性注意力的擴散模型很有價值。

圖 2:與 SSM 和 GLA 相比,線性注意力同樣實現(xiàn) sub-quadratic 的計算復(fù)雜度,同時設(shè)計極其簡潔,且不依賴遞歸的狀態(tài)變量
但是,對于有挑戰(zhàn)性的圖像生成任務(wù),怎么快速,有效地訓(xùn)練好基于線性注意力的擴散模型呢?
這個問題很重要,因為一方面,盡管線性注意力在視覺識別領(lǐng)域已經(jīng)被探索很多,可以取代自注意力,但是在圖像生成中仍然是一個探索不足的問題。另一方面,從頭開始訓(xùn)練擴散模型成本高昂。比如訓(xùn)練 RAPHAEL 需要 60K A100 GPU days ( 中報告)。因此,針對線性擴散 Transformer 的高性價比訓(xùn)練策略仍然值得探索。
LiT 從架構(gòu)設(shè)計和訓(xùn)練策略中系統(tǒng)地研究了純線性注意力的擴散 Transformer 實現(xiàn)。LiT 是一種使用純線性注意力的 Diffusion Transformer。LiT 訓(xùn)練時的成本效率很高,同時在推理過程中保持高分辨率友好屬性,并且可以在 Windows 11 筆記本電腦上離線部署。在基于類別的 ImageNet 256×256 基準上面,100K 訓(xùn)練步數(shù)的 LiT-S/B/L 在 FID 方面優(yōu)于 400K 訓(xùn)練步數(shù)的 DiT-S/B/L。對于 ImageNet 256×256 和 512×512,LiT-XL/2 在訓(xùn)練步驟只有 20% 和 23% 的條件下,實現(xiàn)了與 DiT-XL/2 相當?shù)?FID。在文生圖任務(wù)中,LiT-0.6B 可以在斷網(wǎng)狀態(tài),離線部署在 Windows 筆記本電腦上,遵循用戶指令快速生成 1K 分辨率逼真圖片。
2 線性注意力計算范式

3 線性擴散 Transformer 的架構(gòu)設(shè)計
鑒于對生成任務(wù)上的線性擴散 Transformer 的探索不多,LiT 先以 DiT 為基礎(chǔ),構(gòu)建了一個使用線性注意力的基線模型?;€模型與 DiT 共享相同的宏觀架構(gòu),唯一的區(qū)別是將自注意力替換為 線性注意力。所有實驗均在基于類別的 ImageNet 256×256 基準上進行,使用 256 的 Batch Size 訓(xùn)練了 400K 迭代次數(shù)。
Guideline 1:Simplified 線性注意力對于基于 DiT 的圖像生成擴散模型完全足夠。
我們首先嘗試了在通用視覺基礎(chǔ)模型中成功驗證的常見線性注意力的架構(gòu)設(shè)計,比如 ReLU 線性注意力 (使用 ReLU 激活函數(shù)作為線性注意力的 Kernel Function)。
對于性能參考,將其與 DiT 進行比較,其中任何性能差異都可以歸因于線性注意力對生成質(zhì)量的影響。如圖 4 中所示。與 DiT 相比,使用 ReLU 線性注意力的 LiT-S/2 和 B/2 性能下降很大。結(jié)果表明,視覺識別中常用的線性注意力在噪聲預(yù)測任務(wù)中有改進的空間。
然后我們探索以下方法:
- 簡化型線性注意力 (圖 3,相當于在 ReLU 線性注意力的基礎(chǔ)上加上 Depth-wise 卷積)。
- Focused 線性注意力。
- Focused 線性注意力 (使用 GELU 替換 ReLU)。
這些選擇中的每一個都保持了線性復(fù)雜度,保持了 LiT 在計算效率方面的優(yōu)勢。我們使用相對較大的卷積核 (Kernel Size 5) 來確保在預(yù)測噪聲時足夠大的感受野。

圖 3:在 Simplified 線性注意力中使用更少的 heads

圖 4:不同架構(gòu)的線性注意力消融研究
實驗結(jié)果如圖 4 所示。加了 DWC 的模塊都可以取得大幅的性能提升,我們認為這是因為模型在預(yù)測給定像素的噪聲時關(guān)注相鄰像素信息。同時,我們發(fā)現(xiàn) Focused Function 的有效性有限,我們將其歸因于其設(shè)計動機,以幫助線性注意聚焦于特定區(qū)域。此功能可能適合分類模型,但可能不是噪聲預(yù)測所必需的。為了簡單起見,最后使用簡化 線性注意力。
Guideline 2:在線性注意力中建議使用很少的頭,可以在增加計算的同時不增加時延。
多頭自注意力和線性注意力的計算量分別為:

直覺上似乎使用更多頭可以減少計算壓力。但相反,我們建議使用更少的頭,因為我們觀察到線性注意力存在 Free Lunch 效應(yīng),如圖 5 所示。圖 5 展示了使用線性注意力的 Small,Base,Large,XLarge 模型使用不同頭數(shù)量的延遲和 GMACs 變化。

圖 5:線性注意力中的 Free Lunch 效應(yīng):不同頭數(shù)量線性注意的延遲與理論 GMACs 比較
我們使用 NVIDIA A100 GPU 生成 256×256 分辨率的圖像,批量大小為 8 (NVIDIA V100 GPU 出現(xiàn)類似現(xiàn)象)。結(jié)果表明,減小頭數(shù)量會導(dǎo)致理論 GMACs 穩(wěn)定增加,實際延遲卻并沒有呈現(xiàn)出增加的趨勢,甚至出現(xiàn)下降。我們將這種現(xiàn)象總結(jié)為線性注意力的「免費午餐(Free Lunch)」效應(yīng)。
我們認為在線性注意力中使用更少的頭之后,允許模型有較高的理論計算,根據(jù) scaling law,允許模型在生成性能上達到更高的上限。
實驗結(jié)果如圖 6 所示,對于不同的模型尺度,線性注意力中使用更少的頭數(shù) (比如,2,3,4) 優(yōu)于 DiT 中的默認設(shè)置。相反,使用過多的頭(例如,S/2 的 96 或 B/2 的 192),則會嚴重阻礙生成質(zhì)量。
4 線性擴散 Transformer 的訓(xùn)練方法
LiT 與 DiT 共享一些相同的結(jié)構(gòu),允許權(quán)重繼承自預(yù)訓(xùn)練的 DiT 架構(gòu)。這些權(quán)重包含豐富的與噪聲預(yù)測相關(guān)的知識,有望以成本高效的方式轉(zhuǎn)移到 LiT。因此,在這個部分我們探索把預(yù)訓(xùn)練的 DiT 權(quán)重 (FFN 模塊、adaLN、位置編碼和 Conditional Embedding 相關(guān)的參數(shù)) 繼承給線性 DiT,除了線性注意力部分。

圖 6:線性擴散 Transformer 的權(quán)重繼承策略
Guideline 3:線性擴散 Transformer 的參數(shù)應(yīng)該從一個預(yù)訓(xùn)練到收斂的 DiT 初始化。
我們首先預(yù)訓(xùn)練 DiT-S/2 不同的訓(xùn)練迭代次數(shù):200K、300K、400K、600K 和 800K,并且在每個實驗中,分別將這些預(yù)訓(xùn)練的權(quán)重加載到 LiT-S/2 中,同時線性注意力部分的參數(shù)保持隨機。然后將初始化的 LiT-S/2 在 ImageNet 上訓(xùn)練 400K 迭代次數(shù),結(jié)果如圖 6 所示。
我們觀察到一些有趣的發(fā)現(xiàn):
- DiT 的預(yù)訓(xùn)練權(quán)重,即使只訓(xùn)練了 200K 步,也起著重要作用,將 FID 從 63.24 提高到 57.84。
- 使用預(yù)訓(xùn)練權(quán)重的指數(shù)移動平均 (EMA) 影響很小。
- DiT 訓(xùn)練更收斂時 (800K 步),更適合作為 LiT 的初始化,即使架構(gòu)沒有完全對齊。
我們認為這種現(xiàn)象的一種可能解釋是 Diffusion Transformer 中不同模塊的功能是解耦的。盡管 DiT 和 LiT 具有不同的架構(gòu),但它們的共享組件 (例如 FFN 和 adaLN) 的行為非常相似。因此,可以遷移這些組件預(yù)訓(xùn)練參數(shù)中的知識。同時,即使把 DiT 訓(xùn)練到收斂并遷移共享組件的權(quán)重,也不會阻礙線性注意力部分的優(yōu)化。

圖 7:ImageNet 256×256 上的權(quán)重繼承消融實驗結(jié)果
Guideline 4:線性注意力中的 Query、Key、Value 和 Output 投影矩陣參數(shù)應(yīng)該隨機初始化,不要繼承自自注意力。
在 LiT 中,線性注意力中的一些權(quán)重與 DiT 的自注意力中的權(quán)重重疊,包括 Query、Key、Value 和 Output 投影矩陣。盡管計算范式存在差異,但這些權(quán)重可以直接從 DiT 加載到 LiT 中,而不需要從頭訓(xùn)練。但是,這是否可以加速其收斂性仍然是一個懸而未決的問題。
我們使用經(jīng)過 600K 次迭代預(yù)訓(xùn)練的 DiT-S/2 進行消融實驗。探索了 5 種不同類型的加載策略,包括:
- 加載 Query,Key 和 Value 投影矩陣。
- 加載 Key 和 Value 投影矩陣。
- 加載 Value 投影矩陣。
- 加載 Query 投影矩陣。
- 加載 Output 投影矩陣。
結(jié)果如圖 7 所示。與沒有加載自注意力權(quán)重的基線相比,沒有一個探索的策略顯示出更好的生成性能。這種現(xiàn)象可歸因于計算范式的差異。具體來說,線性注意力直接計算鍵和值矩陣的乘積,但是自注意力就不是這樣的。因此,自注意力中的 Key 和 Value 相關(guān)的權(quán)重對線性注意力的好處有限。
我們建議繼承除線性注意力之外的所有預(yù)訓(xùn)練參數(shù)從預(yù)訓(xùn)練好的 DiT 中,因為它易于實現(xiàn)并且非常適合基于 Transformer 架構(gòu)的擴散模型。

圖 8:混合知識蒸餾訓(xùn)練線性擴散 Transformer
Guideline 5:使用混合知識蒸餾訓(xùn)練線性擴散 Transformer 很關(guān)鍵,不僅蒸餾噪聲預(yù)測結(jié)果,還蒸餾方差的預(yù)測結(jié)果。
知識蒸餾通常采用教師網(wǎng)絡(luò)來幫助訓(xùn)練輕量級學(xué)生網(wǎng)絡(luò)。對于擴散模型,蒸餾通常側(cè)重于減少目標模型的采樣步驟。相比之下,我們專注于在保持采樣步驟的前提下,從復(fù)雜的模型蒸餾出更簡單的模型。



圖 9:ImageNet 256×256 上的知識蒸餾實驗結(jié)果,帶有下劃線的結(jié)果表示不使用知識蒸餾
到目前為止,LiT 遵循 DiT 的宏觀 / 微觀設(shè)計,但采用了高效的線性注意力。使用我們的訓(xùn)練策略,LiT-S/2 顯著地提高了 FID。接下來,我們在更大的變體 (例如 B/L/XL) 和具有挑戰(zhàn)性的任務(wù) (比如 T2I) 上驗證它。
5 圖像生成實驗驗證
ImageNet 256×256 基準
我們首先在 ImageNet 256×256 基準上驗證 LiT。LiT-S/2、B/2、L/2、XL/2 配置與 DiT 一致,只是線性注意力的頭分別設(shè)置為 2/3/4/4。對于所有模型變體,DWC Kernel Size 都設(shè)置為 5。我們以 256 的 Batch Size 訓(xùn)練 400K 步。對于 LiT-XL/2,將訓(xùn)練步數(shù)擴展到 1.4M 步 (只有 DiT-XL/2 7M 的 20%)。我們使用預(yù)訓(xùn)練的 DiT 初始化 LiT 的參數(shù)。Lambda_1 和 lambda_2 在混合知識蒸餾中設(shè)置為 0.5 和 0.05。
圖 10 和 11 比較了 LiT 和 DiT 的不同尺寸模型的結(jié)果。值得注意的是,僅 100K 訓(xùn)練迭代次數(shù)訓(xùn)練的 LiT 已經(jīng)在各種評估指標和不同尺寸的模型中優(yōu)于 400K 訓(xùn)練迭代次數(shù)訓(xùn)練的 DiT。使用 400K 訓(xùn)練迭代次數(shù)的額外訓(xùn)練,模型的性能繼續(xù)提高。盡管訓(xùn)練步驟只有 DiT-XL/2 的 20%,但 LiT-XL/2 仍然取得與 DiT 相當?shù)?FID 結(jié)果 (2.32 對 2.27)。此外,LiT 與基于 U-Net 的基線性能相當。這些結(jié)果表明,當線性注意力結(jié)合合適的優(yōu)化策略時,可以可靠地用于圖像生成應(yīng)用。

圖 10:ImageNet 256×256 基準實驗結(jié)果,與基于自注意力的 DiT 和基于門控線性注意力的 DiG 的比較

圖 11:ImageNet 256×256 基準實驗結(jié)果
ImageNet 512×512 基準
我們繼續(xù)在 ImageNet 512×512 基準上進一步驗證了 LiT-XL/2。使用預(yù)訓(xùn)練的 DiT-XL/2 作為教師模型,使用其權(quán)重初始化 LiT-XL/2。對于知識蒸餾,分別設(shè)置 Lambda_1 和 lambda_2 為 1.0 和 0.05,并且只訓(xùn)練 LiT-XL/2 700K 訓(xùn)練迭代次數(shù) (是 DiT 3M 訓(xùn)練迭代次數(shù)的 23%)。
值得注意的是,與使用 256 的 Batch Size 的 DiT 不同,我們采用 128 的較小 Batch Size。這其實不占便宜,因為 128 的 Batch Size 相比 256 的情況,完成 1 Epoch 需要 2 倍的訓(xùn)練迭代次數(shù)。也就是說,我們 700K 的訓(xùn)練迭代次數(shù)其實只等效為 256 Batch Size 下的 350K。盡管如此,使用純線性注意力的 LiT 實現(xiàn)了 3.69 的 FID,與 3M 步訓(xùn)練的 DiT 相當,將訓(xùn)練步驟減少了約 77%。此外,LiT 優(yōu)于幾個強大的 Baseline。這些結(jié)果證明了我們提出的成本高效的訓(xùn)練策略在高分辨率數(shù)據(jù)集上的有效性。實驗結(jié)果如圖 12 所示。

圖 12:ImageNet 512×512 基準實驗結(jié)果
6 文生圖實驗驗證
文生圖對于擴散模型的商業(yè)應(yīng)用極為重要。LiT 遵循 PixArt-α 的做法,將交叉注意力添加到 LiT-XL/2 中使其支持文本嵌入。LiT 將線性注意力的頭數(shù)設(shè)置為 2,DWC Kernel Size 設(shè)置為 5。遵循 PixArt-Σ 的做法,使用預(yù)訓(xùn)練的 SDXL VAE Encoder 和 T5 編碼器 (即 Flan-T5-XXL) 分別提取圖像和文本特征。
LiT 使用 PixArt-Σ 作為教師來監(jiān)督其訓(xùn)練,分別設(shè)置 Lambda_1 和 lambda_2 為 1.0 和 0.05。LiT 從 PixArt-Σ 繼承權(quán)重,除了自注意力的參數(shù)。隨后在內(nèi)部數(shù)據(jù)集上訓(xùn)練,學(xué)習(xí)率為 2e-5,僅訓(xùn)練 45400 步,明顯低于 PixArt-α 的多階段訓(xùn)練。圖 13 為 LiT 生成的 512px 圖像采樣結(jié)果。盡管在每個 Block 中都使用了線性注意力,以及我們的成本高效的訓(xùn)練策略,LiT 仍然可以產(chǎn)生異常逼真的圖像。

圖 13:LiT 根據(jù)用戶指令生成的 512px 圖片
我們還將分辨率進一步增加到 1K。更多的實驗細節(jié)請參閱原論文。圖 14 是生成的結(jié)果采樣。盡管用廉價的線性注意力替換所有自注意力,但 LiT 仍然能夠以高分辨率生成逼真的圖像。

圖 14:LiT 根據(jù)用戶指令生成的 1K 分辨率圖片
7 離線端側(cè)部署
我們還將 1K 分辨率的 LiT-XL/2 模型部署到一臺 Windows 11 操作系統(tǒng)驅(qū)動的筆記本電腦上,以驗證其 On-device 的能力??紤]到筆記本電腦的 GPU 內(nèi)存的限制,我們將文本編碼器量化為 8-bit,同時在線性注意力計算期間保持 fp16 精度。圖 1 顯示了我們的部署結(jié)果。預(yù)訓(xùn)練的 LiT 可以在離線設(shè)置 (沒有網(wǎng)絡(luò)連接) 的情況下快速生成照片逼真的 1K 分辨率圖像。這些結(jié)果說明 LiT 作為一種 On-device 的擴散模型的成功實現(xiàn),推進邊緣設(shè)備上的高分辨率文生圖任務(wù)。
下面提供了一個視頻 Demo:
https://mp.weixin.qq.com/s/XEQJnt5cJ63spqSG67WLGw?token=1784997338&lang=zh_CN
展示了在斷網(wǎng)狀態(tài)下離線使用 LiT 完成 1K 分辨率文生圖任務(wù)的過程。
熱門跟貼