Object-Centric Learning with Slot Attention

基于Slot Attention的對象中心學習

https://papers.neurips.cc/paper_files/paper/2020/file/8511df98c02ab60aea1b2356c013bc0f-Paper.pdf

打開網(wǎng)易新聞 查看精彩圖片

摘要

學習復雜場景的以對象為中心的表示是實現(xiàn)從低級感知特征中進行高效抽象推理的有前途的一步。然而,大多數(shù)深度學習方法學習到的是分布式的表示,這些表示無法捕捉自然場景的組合性質(zhì)。在本文中,我們提出了槽注意力(Slot Attention)模塊,這是一個架構(gòu)組件,它與感知表示(例如卷積神經(jīng)網(wǎng)絡的輸出)進行交互,并生成一組任務相關(guān)的抽象表示,我們稱之為槽。這些槽是可互換的,可以通過多輪注意力的競爭過程來專門化,從而綁定到輸入中的任何對象。我們通過實證研究證明,槽注意力能夠提取以對象為中心的表示,當在無監(jiān)督的對象發(fā)現(xiàn)任務和有監(jiān)督的屬性預測任務上進行訓練時,能夠?qū)崿F(xiàn)對未見組合的泛化。

1 引言

以對象為中心的表示有潛力提高機器學習算法在多個應用領(lǐng)域的樣本效率和泛化能力,例如視覺推理[1]、結(jié)構(gòu)化環(huán)境建模[2]、多智能體建模[3-5]以及交互式物理系統(tǒng)的模擬[6-8]。從原始感知輸入(如圖像或視頻)中獲得以對象為中心的表示是具有挑戰(zhàn)性的,通常需要監(jiān)督[1, 3, 9, 10]或特定于任務的架構(gòu)[2, 11]。因此,學習以對象為中心的表示這一步驟通常被完全跳過。相反,模型通常被訓練為在從模擬器的內(nèi)部表示[6, 8]或游戲引擎[4, 5]獲得的結(jié)構(gòu)化環(huán)境表示上運行。

為了克服這一挑戰(zhàn),我們引入了槽注意力模塊,這是一個可微分的接口,介于感知表示(例如CNN的輸出)和一組稱為槽的變量之間。通過迭代注意力機制,槽注意力生成一組具有排列對稱性的輸出向量。與膠囊網(wǎng)絡[12, 13]中使用的膠囊不同,槽注意力生成的槽不會專門化為某一特定類型或類別的對象,這可能會損害泛化能力。相反,它們類似于對象文件[14],即槽使用一種通用的表示格式:每個槽都可以存儲(并綁定到)輸入中的任何對象。這使得槽注意力能夠以系統(tǒng)的方式泛化到未見的組合、更多的對象和更多的槽。

槽注意力是一個簡單且易于實現(xiàn)的架構(gòu)組件,可以放置在例如CNN[15]編碼器的頂部,從圖像中提取對象表示,并與下游任務一起端到端地進行訓練。在本文中,我們將圖像重建和集合預測作為下游任務,以展示我們模塊在具有挑戰(zhàn)性的無監(jiān)督對象發(fā)現(xiàn)設置以及涉及集合結(jié)構(gòu)對象屬性預測的有監(jiān)督任務中的多功能性。

我們的主要貢獻如下:

(i) 我們引入了槽注意力(Slot Attention)模塊,這是一個簡單的架構(gòu)組件,位于感知表示(例如卷積神經(jīng)網(wǎng)絡的輸出)與以集合形式結(jié)構(gòu)化的表示之間的接口。

(ii) 我們將基于槽注意力的架構(gòu)應用于無監(jiān)督的對象發(fā)現(xiàn)任務中,在該任務中,它與相關(guān)最先進的方法[16, 17]相匹配或優(yōu)于這些方法,同時具有更高的內(nèi)存效率,并且訓練速度顯著更快。

(iii) 我們證明了槽注意力模塊可以用于有監(jiān)督的對象屬性預測任務,在這些任務中,注意力機制能夠?qū)W會突出顯示單個對象,而無需直接的對象分割監(jiān)督。

2 方法
在本節(jié)中,我們介紹槽注意力(Slot Attention)模塊(圖1a;第2.1節(jié)),并展示如何將其集成到用于無監(jiān)督對象發(fā)現(xiàn)的架構(gòu)中(圖1b;第2.2節(jié))以及集成到集合預測架構(gòu)中(圖1c;第2.3節(jié))。

2.1 槽注意力模塊

打開網(wǎng)易新聞 查看精彩圖片
打開網(wǎng)易新聞 查看精彩圖片
打開網(wǎng)易新聞 查看精彩圖片
打開網(wǎng)易新聞 查看精彩圖片
打開網(wǎng)易新聞 查看精彩圖片

證明在補充材料中。排列等變性屬性對于確保槽學習一種通用的表示格式以及每個槽都可以綁定到輸入中的任何對象是至關(guān)重要的。

2.2 對象發(fā)現(xiàn)

以集合形式結(jié)構(gòu)化的隱藏表示是無監(jiān)督學習場景中對象的一種有吸引力的選擇:每個集合元素可以捕捉場景中一個對象的屬性,而無需假設對象被描述的特定順序。由于槽注意力將輸入表示轉(zhuǎn)換為一組向量,因此它可以作為自編碼器架構(gòu)中編碼器的一部分,用于無監(jiān)督的對象發(fā)現(xiàn)。自編碼器的任務是將圖像編碼為一組隱藏表示(即槽),這些槽可以被解碼回圖像空間以重建原始輸入。因此,槽作為表示瓶頸,解碼器的架構(gòu)(或解碼過程)通常被設計為每個槽只解碼圖像的一個區(qū)域或部分[16, 17, 24–27]。這些區(qū)域/部分隨后被組合以得到完整的重建圖像。

編碼器 我們的編碼器由兩個部分組成:(i) 帶有位置嵌入的CNN主干網(wǎng)絡,隨后是 (ii) 一個槽注意力模塊。槽注意力的輸出是一組槽,這些槽表示場景的一種分組(例如,以對象為單位)。

解碼器 每個槽都通過空間廣播解碼器[28]單獨解碼,這與IODINE[16]中使用的方法相同:槽表示被廣播到一個二維網(wǎng)格(每個槽一個)上,并附加位置嵌入。每個這樣的網(wǎng)格使用CNN(參數(shù)在槽之間共享)進行解碼,以產(chǎn)生大小為 的輸出,其中 W 和 H 分別是圖像的寬度和高度。輸出通道編碼RGB顏色通道和一個(未歸一化的)alpha蒙版。隨后,我們使用Softmax對槽的alpha蒙版進行歸一化,并將它們用作混合權(quán)重,將各個重建結(jié)果組合成一個單一的RGB圖像。

2.3 集合預測

集合表示在許多數(shù)據(jù)模態(tài)的任務中被廣泛應用,范圍從點云預測[29, 30]、圖像中多個對象的分類[31],到生成具有期望屬性的分子[32, 33]。在本文考慮的例子中,我們給定一個輸入圖像和一組預測目標,每個目標描述場景中的一個對象。預測集合的關(guān)鍵挑戰(zhàn)在于,對于包含 K 個元素的集合,存在 K! 種可能的等價表示,因為目標的順序是任意的。這種歸納偏差需要在架構(gòu)中明確建模,以避免在學習過程中出現(xiàn)不連續(xù)性,例如在訓練過程中兩個語義專門化的槽交換它們的內(nèi)容時[31, 34]。槽注意力的輸出順序是隨機的,并且與輸入順序無關(guān),這解決了這一問題。因此,槽注意力可以將輸入場景的分布式表示轉(zhuǎn)換為集合表示,在這種表示中,每個對象都可以使用標準分類器分別進行分類,如圖1c所示。

編碼器 我們使用與對象發(fā)現(xiàn)設置(第2.2節(jié))中相同的編碼器架構(gòu),即帶有位置嵌入的CNN主干網(wǎng)絡,隨后是槽注意力,以得到一組槽表示。

分類器 對于每個槽,我們應用一個MLP,其參數(shù)在槽之間共享。由于預測和標簽的順序都是任意的,我們使用匈牙利算法[35]將它們匹配。我們將探索其他匹配算法[36, 37]的工作留給未來研究。

3 相關(guān)工作

對象發(fā)現(xiàn) 我們的對象發(fā)現(xiàn)架構(gòu)與最近一系列關(guān)于組合生成場景模型的工作密切相關(guān),這些模型將場景表示為具有相同表示格式的潛在變量集合[16, 17, 24–27, 38–44]。與我們的方法最接近的是IODINE [16]模型,該模型使用迭代變分推斷[45]來推斷一組潛在變量,每個變量描述圖像中的一個對象。在每次推斷迭代中,IODINE執(zhí)行一個解碼步驟,隨后在像素空間進行比較以及后續(xù)的編碼步驟。與之相關(guān)的模型,如MONet [17]和GENESIS [27],也類似地使用多個編碼-解碼步驟。而我們的模型則用迭代注意力的單次編碼步驟取代了這一過程,從而提高了計算效率。此外,這使得我們的架構(gòu)即使在沒有解碼器的情況下,也能夠推斷對象表示和注意力掩碼,為超越自編碼的擴展提供了可能,例如用于對象發(fā)現(xiàn)的對比表示學習[46],或者直接優(yōu)化下游任務,如控制或規(guī)劃。我們的基于注意力的路由過程也可以與使用基于塊的解碼器的架構(gòu)(如AIR [26]、SQAIR [40]以及相關(guān)方法[41–44])結(jié)合使用,作為通常采用的自回歸編碼器[26, 40]的替代方案。我們的方法與使用對抗訓練[47–49]或?qū)Ρ葘W習[46]進行對象發(fā)現(xiàn)的方法是正交的:在這些設置中使用Slot Attention是一個未來工作的一個有趣方向。

集合的神經(jīng)網(wǎng)絡 近期一系列方法探索了集合編碼[34, 50, 51]、生成[31, 52]以及集合到集合的映射[20, 53]。圖神經(jīng)網(wǎng)絡[54–57],尤其是Transformer模型[20]中的自注意力機制,常用于處理具有固定基數(shù)(即集合元素數(shù)量)的元素集合。Slot Attention解決了從一個集合映射到另一個不同基數(shù)的集合的問題,同時尊重輸入和輸出集合的排列對稱性。Deep Set Prediction Network(DSPN)[31, 58]通過為每個樣本運行一個內(nèi)部梯度下降循環(huán)來尊重排列對稱性,但這需要許多步才能收斂,并且需要仔細調(diào)整多個損失超參數(shù)。相比之下,Slot Attention僅通過幾次注意力迭代和一個特定于任務的損失函數(shù)直接從集合映射到集合。在同期工作中,DETR[59]和TSPN[60]模型提出使用Transformer[20]進行條件集合生成。大多數(shù)相關(guān)方法,包括DiffPool[61]、Set Transformers[53]、DSPN[31]和DETR[59],都使用了針對每個元素的學習初始化(即為每個集合元素設置單獨的參數(shù)),這使得這些方法無法在測試時推廣到更多集合元素。

迭代路由 我們的迭代注意力機制與通常用于膠囊網(wǎng)絡變體[12, 13, 62]中的迭代路由機制有相似之處。最接近的變體是倒置點積注意力路由[62],它同樣使用點積注意力機制來獲得表示之間的分配系數(shù)。然而,他們的方法(與其他膠囊模型一致)沒有排列對稱性,因為每個輸入-輸出對都被分配了一個單獨參數(shù)化的變換。兩種方法在注意力機制的歸一化細節(jié)、更新的聚合方式以及考慮的應用方面存在顯著差異。

交互式記憶模型 Slot Attention可以被視為交互式記憶模型[9, 39, 46, 63–68]的一個變體,這些模型利用一組槽位及其成對交互來推理輸入中的元素(例如視頻中的對象)。這些模型的共同組成部分包括:(i) 獨立作用于各個槽位的遞歸更新函數(shù),以及(ii) 引入槽位之間通信的交互函數(shù)。通常,這些模型中的槽位是完全對稱的,所有槽位共享相同的遞歸更新函數(shù)和交互函數(shù),唯一的例外是RIM模型[67],它為每個槽位使用單獨的一組參數(shù)。值得注意的是,RMC[63]和RIM[67]引入了注意力機制,用于將輸入信息聚合到槽位中。在Slot Attention中,輸入到槽位的基于注意力的分配是針對槽位進行歸一化的(而不僅僅是針對輸入),這引入了槽位之間的競爭,以對輸入進行聚類。此外,我們在本工作中不考慮時間序列數(shù)據(jù),而是使用遞歸更新函數(shù)來迭代細化對單個靜態(tài)輸入的預測。

專家混合模型 專家模型[67, 69–72]與我們的基于槽位的方法相關(guān),但個體專家之間并不完全共享參數(shù)。這導致個體專家對不同的任務或?qū)ο箢愋瓦M行專業(yè)化。在Slot Attention中,槽位使用共同的表示格式,每個槽位都可以綁定到輸入的任何部分。

聚類 我們的路由過程與軟k均值聚類[73](其中槽位對應于聚類中心)相關(guān),但有兩個關(guān)鍵區(qū)別:我們使用帶有學習線性投影的點積相似性,并且我們使用參數(shù)化、可學習的更新函數(shù)。計算機視覺[74]和語音識別領(lǐng)域[75]引入了具有可學習、聚類特定參數(shù)的軟k均值聚類變體,但它們與我們的方法不同,因為它們不使用遞歸的多步更新,也不尊重排列對稱性(聚類中心在訓練后充當固定的有序字典)。Set Transformer[53]的誘導點機制和DETR[59]中的圖像到槽位的注意力機制可以被視為這些有序、單步方法的擴展,它們?yōu)槊總€聚類分配使用多個注意力頭(即多個相似性函數(shù))。

循環(huán)注意力 我們的方法與用于圖像建模和場景分解[26, 40, 76–78]以及集合預測[79]的循環(huán)注意力模型相關(guān)。在這一背景下,也考慮了不使用注意力機制的集合預測循環(huán)模型[80, 81]。這一系列工作經(jīng)常使用排列不變的損失函數(shù)[79, 80, 82],但依賴于以自回歸的方式在每個時間步推斷一個槽位、表示或標簽,而Slot Attention在每一步同時更新所有槽位,因此完全尊重排列對稱性。

4 實驗

本節(jié)的目標是在兩個以對象為中心的任務上評估Slot Attention模塊,其中一個是有監(jiān)督的,另一個是無監(jiān)督的,分別如第2.2節(jié)和第2.3節(jié)所述。我們將與每個任務的專門的最先進方法[16, 17, 31]進行比較。我們會在補充材料中提供關(guān)于實驗和實現(xiàn)的更多細節(jié),以及額外的定性結(jié)果和消融研究。

基線 在無監(jiān)督的對象發(fā)現(xiàn)實驗中,我們與兩個最近的最先進模型進行比較:IODINE[16]和MONet[17]。對于有監(jiān)督的對象屬性預測,我們與Deep Set Prediction Networks(DSPN)[31]進行比較。據(jù)我們所知,DSPN是唯一尊重排列對稱性的集合預測模型,除了我們提出的模型之外。在這兩個任務中,我們還與一個簡單的基于MLP的基線進行比較,我們將其稱為Slot MLP。該模型用一個MLP替換Slot Attention,將CNN特征圖(調(diào)整大小并展平)映射到(現(xiàn)在是有序的)槽表示。對于MONet、IODINE和DSPN基線,我們與[16, 31]中發(fā)布的數(shù)字進行比較,因為我們使用了相同的實驗設置。

數(shù)據(jù)集 在對象發(fā)現(xiàn)實驗中,我們使用以下三個多對象數(shù)據(jù)集[83]:CLEVR(帶掩碼)、Multi-dSprites和Tetrominoes。CLEVR(帶掩碼)是帶有分割掩碼注釋的CLEVR數(shù)據(jù)集的版本。與IODINE[16]類似,我們僅使用CLEVR(帶掩碼)數(shù)據(jù)集的前70K樣本進行訓練,并裁剪圖像以突出顯示中心的對象。對于Multi-dSprites和Tetrominoes,我們使用前60K樣本。正如在[16]中所做的那樣,我們在320個測試樣本上評估對象發(fā)現(xiàn)。對于集合預測,我們使用原始的CLEVR數(shù)據(jù)集[84],其中包含分別用于訓練和驗證的70K和15K張渲染對象的圖像。每張圖像可以包含三到十個對象,并且每個對象都有屬性注釋(位置、形狀、材質(zhì)、顏色和大?。?。在某些實驗中,我們將CLEVR數(shù)據(jù)集過濾為僅包含最多6個對象的場景;我們將這個數(shù)據(jù)集稱為CLEVR6,并為了清晰起見,將原始完整數(shù)據(jù)集稱為CLEVR10。

4.1 對象發(fā)現(xiàn)

訓練 訓練設置是無監(jiān)督的:學習信號由(均方)圖像重建誤差提供。我們使用Adam優(yōu)化器[85]進行訓練,學習率為4×10??,批量大小為64(使用單個GPU)。此外,我們還使用學習率預熱[86]以防止注意力機制過早飽和,并使用指數(shù)衰減的學習率調(diào)度,我們發(fā)現(xiàn)這可以減少方差。在訓練時,我們使用T = 3次Slot Attention迭代。我們在所有數(shù)據(jù)集上使用相同的訓練設置,除了槽的數(shù)量K:對于CLEVR6,我們使用K = 7個槽;對于Multi-dSprites(每個場景最多5個對象),我們使用K = 6個槽;對于Tetrominoes(每個場景3個對象),我們使用K = 4個槽。盡管Slot Attention中的槽數(shù)量可以為每個輸入樣本設置不同的值,但我們在訓練集中對所有樣本使用相同的值K,以便于批量處理。

評估指標 與以往的工作[16, 17]一致,我們使用調(diào)整后的蘭德指數(shù)(Adjusted Rand Index, ARI)分數(shù)[87, 88]比較解碼器產(chǎn)生的alpha掩碼(針對每個單獨的對象槽)與真實分割(不包括背景)。ARI是一個用于衡量聚類相似性的分數(shù),范圍從0(隨機)到1(完美匹配)。為了計算ARI分數(shù),我們使用了Kabra等人[83]提供的實現(xiàn)。

結(jié)果 定量結(jié)果總結(jié)于表1和圖2??傮w而言,我們觀察到我們的模型與兩個最近的最先進基線模型相比具有優(yōu)勢:IODINE[16]和MONet[17]。我們還與一個簡單的基于MLP的基線(Slot MLP)進行了比較,該基線的表現(xiàn)優(yōu)于隨機水平,但由于其有序的表示方式,無法模擬該任務的組合性質(zhì)。我們注意到我們模型的一個失敗模式:在極少數(shù)情況下,它可能會在Tetrominoes數(shù)據(jù)集上陷入次優(yōu)解,其中它將圖像分割成條紋。這導致訓練集上的重建誤差顯著更高,因此這種異常值可以在訓練時輕松識別。我們在表1的最終得分中排除了一個這樣的異常值(5個種子中的1個)。我們預計,仔細調(diào)整訓練超參數(shù),特別是針對這個數(shù)據(jù)集,可能會緩解這一問題,但為了簡單起見,我們選擇在所有數(shù)據(jù)集上使用單一設置。

打開網(wǎng)易新聞 查看精彩圖片

與IODINE[16]相比,Slot Attention在內(nèi)存消耗和運行時間方面顯著更高效。在CLEVR6上,我們可以在單個具有16GB RAM的V100 GPU上使用高達64的批量大小,而[16]中使用相同硬件類型的批量大小為4。同樣,當使用8個V100 GPU并行運行時,Slot Attention在CLEVR6上的模型訓練大約需要24小時,而IODINE[16]則大約需要7天。

在圖2中,我們研究了在測試時使用更多的Slot Attention迭代次數(shù)時,我們的模型能夠泛化到什么程度,而訓練時使用的是固定的T = 3次迭代。我們進一步評估了與訓練集(CLEVR6)相比,對更多對象(CLEVR10)的泛化能力。我們觀察到,當使用更多迭代時,分割分數(shù)顯著高于表1中報告的數(shù)字。當在包含更多對象的CLEVR10場景上進行測試時,這種改進更為顯著。在本實驗中,我們將槽的數(shù)量從訓練時的K = 7增加到測試時的K = 11。總體而言,即使在測試包含比訓練期間看到的更多對象的場景時,分割性能仍然很強。

我們在圖3中為所有三個數(shù)據(jù)集可視化了發(fā)現(xiàn)的對象分割結(jié)果。如果槽的數(shù)量多于對象的數(shù)量,模型會學會讓某些槽保持為空(僅捕獲背景)。我們發(fā)現(xiàn),Slot Attention通常會將均勻的背景分散到所有槽中,而不是僅在一個槽中捕獲背景,這可能是注意力機制的一個副作用,但并不會損害對象解耦或重建質(zhì)量。我們進一步可視化了注意力機制在各個注意力迭代過程中如何分割場景,并檢查了每次單獨迭代后的場景重建結(jié)果(模型僅在最終迭代后被訓練用于重建)。可以看到,注意力機制在第二次迭代時已經(jīng)學會專注于提取單個對象,而第一次迭代的注意力圖仍然將多個對象的部分映射到同一個槽中。

打開網(wǎng)易新聞 查看精彩圖片

為了評估Slot Attention是否能夠在不依賴顏色線索的情況下執(zhí)行分割,我們還在具有白色對象和黑色背景的二值化multi-dSprites數(shù)據(jù)集以及灰度版CLEVR6數(shù)據(jù)集上進行了實驗。我們使用了Kabra等人[83]提供的二值化multi-dSprites數(shù)據(jù)集,Slot Attention在該數(shù)據(jù)集上使用K = 4個槽時達到了69.4 ± 0.9%的ARI分數(shù),相比之下,IODINE[16]的分數(shù)為64.8 ± 17.2%,R-NEM[39]的分數(shù)為68.5 ± 1.7%,如[16]中所報告。Slot Attention在僅基于形狀線索分解場景為對象方面表現(xiàn)出競爭力。我們在圖4中可視化了在灰度版CLEVR6上訓練的Slot Attention模型發(fā)現(xiàn)的對象分割結(jié)果,盡管缺乏顏色作為區(qū)分對象的特征,Slot Attention仍然能夠很好地處理。

由于我們的對象發(fā)現(xiàn)架構(gòu)使用了與IODINE[16]相同的解碼器和重建損失,因此我們預計它在處理包含更復雜背景和紋理的場景時也會遇到類似的困難。使用不同的感知損失[49, 89]或?qū)Ρ葥p失[46]可能有助于克服這一限制。我們將在第5節(jié)和補充材料中進一步討論限制和未來工作。

總結(jié) Slot Attention在無監(jiān)督場景分解方面與以往的方法高度競爭,無論是在對象分割的質(zhì)量上,還是在訓練速度和內(nèi)存效率上。在測試時,Slot Attention可以在沒有解碼器的情況下使用,以從未見的場景中獲得以對象為中心的表示。

4.2 集合預測

訓練 我們使用與第4.1節(jié)相同的超參數(shù)來訓練模型,除了我們使用了512的批量大小,并在編碼器中使用了步幅(striding)。在CLEVR10上,我們使用K = 10個對象槽,以與[31]保持一致。Slot Attention模型使用單個具有16GB RAM的NVIDIA Tesla V100 GPU進行訓練。

評估指標 根據(jù)Zhang等人的方法,我們計算了平均精度(Average Precision, AP),這是目標檢測中常用的指標。如果預測的對象屬性(形狀、材質(zhì)、顏色和大?。┡c真實對象完全一致,并且在一定的距離閾值內(nèi)(∞表示不設置閾值),則認為預測是正確的。預測的位置坐標被縮放到[?3, 3]范圍內(nèi)。我們將目標進行零填充,并預測一個額外的指示分數(shù),范圍在[0, 1]之間,表示對象存在的概率(1表示存在對象),然后將其用作預測置信度來計算AP。

打開網(wǎng)易新聞 查看精彩圖片

結(jié)果 在圖5(左)中,我們報告了在CLEVR10上進行有監(jiān)督對象屬性預測的平均精度結(jié)果(在訓練和測試時,Slot Attention均使用T = 3)。我們與[31]中的DSPN結(jié)果以及Slot MLP基線進行了比較。總體而言,我們的方法與DSPN基線相當或優(yōu)于DSPN基線。在更具挑戰(zhàn)性的距離閾值下(針對對象位置特征),我們的方法性能下降較為平緩,并且方差較小。需要注意的是,DSPN基線[31]使用了深度顯著更深的ResNet 34作為圖像編碼器。在圖5(中)中,我們觀察到在測試時增加注意力迭代次數(shù)通常可以提高性能。Slot Attention可以通過改變槽的數(shù)量自然地處理測試時更多的對象。在圖5(右)中,我們觀察到如果在CLEVR6(K = 6個槽)上訓練模型,并在測試時使用更多對象,AP會平緩下降。直觀上,為了解決這個集合預測任務,每個槽應該關(guān)注不同的對象。在圖6中,我們可視化了兩個CLEVR圖像中每個槽的注意力圖??傮w而言,我們觀察到注意力圖自然地分割了對象。我們指出,該方法僅被訓練用于預測對象的屬性,而沒有使用任何分割掩碼。從定量角度看,我們可以評估注意力掩碼的調(diào)整蘭德指數(shù)(Adjusted Rand Index, ARI)分數(shù)。在帶有掩碼的CLEVR10上,Slot Attention生成的注意力掩碼達到了78.0% ± 2.9的ARI分數(shù)(為了計算ARI,我們將輸入圖像縮放到32 × 32)。需要注意的是,表1中評估的掩碼并非注意力圖,而是由對象發(fā)現(xiàn)解碼器預測的。

總結(jié) Slot Attention學習了用于集合結(jié)構(gòu)屬性預測任務的對象表示,并取得了與以往最先進方法相當?shù)慕Y(jié)果,同時在實現(xiàn)和調(diào)整方面要簡單得多。此外,注意力掩碼自然地分割了場景,這對于調(diào)試和解釋模型的預測結(jié)果非常有價值。

5 結(jié)論

我們提出了Slot Attention模塊,這是一個多功能的架構(gòu)組件,能夠從未加工的感知輸入中學習以對象為中心的抽象表示。Slot Attention中使用的迭代注意力機制使我們的模型能夠?qū)W習一種分組策略,將輸入特征分解為一組槽表示。在無監(jiān)督視覺場景分解和有監(jiān)督對象屬性預測的實驗中,我們已經(jīng)證明Slot Attention與以往相關(guān)方法高度競爭,同時在內(nèi)存消耗和計算方面更加高效。

一個自然的下一步是將Slot Attention應用于視頻數(shù)據(jù)或其他數(shù)據(jù)模態(tài),例如用于圖中的節(jié)點聚類、基于點云處理的背景、文本或語音數(shù)據(jù)。研究其他下游任務,如獎勵預測、視覺推理、控制或規(guī)劃,也是很有前景的。

廣義影響 Slot Attention模塊能夠從未加工的感知輸入中學習以對象為中心的表示。因此,它是一個可以在廣泛領(lǐng)域和應用中使用的通用模塊。在我們的論文中,我們僅考慮了在嚴格控制設置下人工生成的數(shù)據(jù)集,這些數(shù)據(jù)集中槽位被期望專門化為對象。然而,我們模型的專門化是隱式的,并完全由下游任務驅(qū)動。我們指出,作為一種具體措施,以評估模塊是否以不希望的方式專門化,可以可視化注意力掩碼,以了解輸入特征是如何分布在各個槽中的(見圖6)。盡管需要更多的工作來適當?shù)亟鉀Q注意力系數(shù)在解釋網(wǎng)絡整體預測中的有用性(特別是當輸入特征對人類不可解釋時),但我們認為它們可能是實現(xiàn)更透明和可解釋預測的一步。

原文鏈接: https://papers.neurips.cc/paper_files/paper/2020/file/8511df98c02ab60aea1b2356c013bc0f-Paper.pdf