0
作者 | Marin Vlastelica
編譯 | 蔣寶尚
(雷鋒網(wǎng)作品)目前,在計算機(jī)這個學(xué)科中有兩個非常重要方向:一個是離散優(yōu)化的經(jīng)典算法-圖算法,例如SAT求解器、整數(shù)規(guī)劃求解器;另一個是近幾年崛起的深度學(xué)習(xí),它使得數(shù)據(jù)驅(qū)動的特征提取以及端到端體系結(jié)構(gòu)的靈活設(shè)計成為可能。
那么能否將組合器與深度學(xué)習(xí)相結(jié)合?
ICLR 2020 spotlight 論文《Differentiation of Blackbox Combinatorial Solvers》探討了這一問題。
論文下載地址:https://arxiv.org/abs/1912.02175
在論文中,作者試著將組合求解器無縫融入深度神經(jīng)網(wǎng)絡(luò),并在魔獸爭霸最短路徑問題、最小損失完美匹配問題以及旅行商問題中進(jìn)行了測試。測試結(jié)果顯示,其組合求解器+深度學(xué)習(xí)的方法達(dá)到的效果比傳統(tǒng)的方法要好。
另外,論文的一作Marin Vlastelica,在Medium上撰文介紹了這篇論文的主要思想,雷鋒網(wǎng) AI科技評論作了有刪改的編譯,以下是原文請欣賞~
機(jī)器學(xué)習(xí)的研究現(xiàn)狀表明,基于深度學(xué)習(xí)的現(xiàn)代方法與傳統(tǒng)的人工智能方法確實(shí)存在不一致的地方。深度學(xué)習(xí)在計算機(jī)視覺、強(qiáng)化學(xué)習(xí)、自然語言處理等領(lǐng)域的特征提取方面有著強(qiáng)大的功能。雖然如此,但其在組合泛化問題(combinatorial generalization)上一直是研究者所詬病。
例如,將地圖作為輸入從而在 Google Maps 上預(yù)測最快路線的最短路徑的規(guī)劃問題;(Min,Max)-Cut 問題、最小損失完美匹配問題(Min-Cost Perfect Matching)、旅行商問題、圖匹配問題等等。如果單獨(dú)解決上述每一個問題,我們有很多工具可以選擇:你可以用C語言,可以使用更通用的 MIP(mixed integer programming)求解器。當(dāng)然求解器需要考慮輸入空間問題,畢竟它需要定義良好的結(jié)構(gòu)化輸入。雖然組合問題已經(jīng)成為機(jī)器學(xué)習(xí)研究領(lǐng)域的關(guān)注點(diǎn),但對此類問題的研究力度尚且不足。
這也不是說研究者不重視組合泛化問題,畢竟它仍然是智能系統(tǒng)的關(guān)鍵挑戰(zhàn)之一。理想情況下,研究者能夠以端對端方式,通過強(qiáng)大的函數(shù)逼近器(如神經(jīng)網(wǎng)絡(luò))將豐富的特征提取與高效的組合求解器結(jié)合起來。這也正是論文《Differentiation of Blackbox Combinatorial Solvers》中所實(shí)現(xiàn)的,另外,這篇論文獲得了很高的評審分?jǐn)?shù),并入選為 ICLR 2020 spotlight 論文。文章接下來的部分,并不是在試圖改進(jìn)求解器,而是要將函數(shù)逼近和現(xiàn)有求解器協(xié)同使用。
假設(shè)黑盒求解器(blackbox solver)是一個可以輕松插入深度學(xué)習(xí)的結(jié)構(gòu)模塊。
將連續(xù)輸入到離散輸出之間的映射作為求解器的方式,另外,連續(xù)輸入可以是圖邊的權(quán)重,離散輸出可以是最短路徑、選定的圖邊。其中,映射的定義如下
求解器可以將最小化一些損失函數(shù)c(ω,y),這些損失函數(shù)可以是路徑的長度。用公式這種優(yōu)化問題表示如下:
上式中,w為神經(jīng)網(wǎng)絡(luò)的輸出,也就是神經(jīng)網(wǎng)絡(luò)學(xué)習(xí)的某種表示,例如可以是圖邊權(quán)重的某個向量。在最短路徑問題、旅行商問題中,ω可以用來作出正確的問題描述。優(yōu)化問題的關(guān)鍵是最小化損失函數(shù),現(xiàn)在的問題是損失函數(shù)是分段表示的,也就是說存在跳躍間斷點(diǎn)。這意味著對于表示 ω,該函數(shù)的梯度幾乎處處為 0,并且在跳躍間斷點(diǎn)處,梯度尚未被定義。目前,利用求解器松弛(solver relaxation)的方法能夠解決這個問題,但會損失最優(yōu)性。論文中提出了一種不影響求解器最優(yōu)性的方法。即對原始目標(biāo)函數(shù)的分段處用仿射插值來定義,另外插值由超參數(shù) λ 控制,如下圖所示:
如上所示,函數(shù)圖像的黑色部分是原函數(shù)給出的值,橙色部分是利用插值法給出的值。最小值沒有變化。
當(dāng)然,f的域是多維的。因此,對于同一個f的取值,可以有多個w相對應(yīng)。也就是說輸入的ω的集合是一個多面體,輸出的f可以是相同的值。自然地,在 f 的域中有許多這樣的多面體。超參數(shù) λ 有效地通過擾動求解器輸入 ω 來使多面體偏移。定義了分段仿射目標(biāo)的插值器 g 將多面體的偏移邊界與原始邊界相連。
如下圖所示,取值 f(y2) 的多面體邊界偏移至了取值 f(y1) 處。這也直觀地解釋了為什么更傾向使用較大的超參數(shù)λ。偏移量必須足夠大才能獲得提供有用梯度的插值器g
首先,定義一個擾動優(yōu)化問題的解決方案,其中擾動由超參數(shù)λ控制,公式如下:
如果假設(shè)損失函數(shù)c(ω,y)是y和ω之間的點(diǎn)積,則可以定義插值目標(biāo):
損失函數(shù)的線性度并不像乍一看那樣有限制性。例如,在邊選擇問題中,損失函數(shù)要考慮所有邊權(quán)重的和,具體事例參考旅行商問題和最短路徑問題。
雷鋒網(wǎng)注:如上圖所示,插值隨著超參數(shù)λ的變化而變化
使用該方法,可以通過修改反向傳播來計算梯度,從而消除經(jīng)典組合求解器和深度學(xué)習(xí)之間的不一致性。
def forward(ctx, w_):
"""
ctx: Context for backward pass
w_: Estimated problem weights
"""
y_ = solver(w_)
# Save context for backward pass
ctx.w_ = w_
ctx.y_ = y_ return y_
在前向傳播中,只需給嵌入求解器提供 ω,然后將解向前傳遞。此外,我們保存了 ω 和在前向傳播中計算得到的解 y_。
def backward(ctx, grad):
"""
ctx: Context from forward pass
"""
w = ctx.w_ + lmda*grad # Calculate perturbed weights
y_lmda = solver(w)
return -(ctx.y_ - y_lmda)/lmda
在后向傳遞中,用超參數(shù)λ的反向傳播梯度來擾動 ω,并取先前解與擾動問題解之間的差值
計算插值梯度的計算開銷取決于求解器,額外的開銷有兩次,一次是在前向傳播過程中調(diào)用的一次求解器,另一次是在后向傳播過程中調(diào)用的一次求解器。
為了驗(yàn)證該方法,設(shè)計了具有一定程度復(fù)雜度的合成任務(wù)進(jìn)行驗(yàn)證。
另外,因?yàn)楹唵蔚谋O(jiān)督學(xué)習(xí)方法無法泛化至沒有見過的數(shù)據(jù),所以在下面的任務(wù)中,已經(jīng)證明了此方法對于組合泛化的必要性。
對于最短路徑問題,測試任務(wù)為魔獸爭霸,訓(xùn)練集包括《魔獸爭霸 II》地圖,任務(wù)目標(biāo)為地圖對應(yīng)的最短路徑問題。具體而言,測試集包含了未知的《魔獸爭霸 II》地圖。地圖本身編碼為K*K網(wǎng)格。卷積神經(jīng)網(wǎng)絡(luò)的輸入是地圖,輸出地圖是頂點(diǎn)的損失,然后將該損失作為求解器的輸入。最后,求解器(Dijkstra 最短路徑算法)以指示矩陣的形式在地圖上輸出最短路徑。
在訓(xùn)練的開始,神經(jīng)網(wǎng)絡(luò)不知道如何為地圖的圖塊分配正確的損失,但是使用組合求解器+深度學(xué)習(xí)能夠得到正確的成本,從而找到正確的最短路徑。下列直方圖表明,相比于 ResNet 的傳統(tǒng)監(jiān)督訓(xùn)練方法,此方法的組合泛化能力更棒。
在最小損失完美匹配問題上,使用的數(shù)據(jù)集是MNIST,任務(wù)目標(biāo)是輸出 MNIST 數(shù)字組成網(wǎng)格的最小損失完美匹配。具體而言,在此問題上,選擇的邊應(yīng)該讓所有的頂點(diǎn)都能夠恰好被包含一次,另外還能夠讓損失之和最小。另外,網(wǎng)格中的每個單元都包含一個 MNIST 數(shù)字,該數(shù)字是圖中具備垂直和水平方向鄰近點(diǎn)的一個節(jié)點(diǎn)。最后,邊的損失由垂直向下或水平向右的兩位數(shù)字決定。
求解器輸出匹配中所選邊的指示向量。右側(cè)的匹配損失為 348(水平為 46 + 12,垂直為 27 + 45 + 40 + 67 + 78 + 33)。
在下面這張性能圖上,我們可以清晰看到在神經(jīng)網(wǎng)絡(luò)中嵌入真實(shí)的完美匹配求解器能夠達(dá)到更好的效果。
在旅行商問題中,訓(xùn)練數(shù)據(jù)集是國旗(即原始表示)和對應(yīng)首都的最優(yōu)旅行線路。神經(jīng)網(wǎng)絡(luò)的輸出是各個國家首都的最佳旅行線路。神經(jīng)網(wǎng)絡(luò)在訓(xùn)練的過程,最重要的學(xué)習(xí)首都位置的隱表示。包含K個國家的訓(xùn)練示例如下圖所示。
將各個國家的國旗輸入卷積神經(jīng)網(wǎng)絡(luò),然后網(wǎng)絡(luò)輸出最優(yōu)旅行線路。
在下面的動畫中,也可以看到神經(jīng)網(wǎng)絡(luò)訓(xùn)練期間各國首都在全球范圍內(nèi)的位置。
起初,位置是隨機(jī)分布的,但經(jīng)過訓(xùn)練后,神經(jīng)網(wǎng)絡(luò)不僅學(xué)習(xí)輸出正確的TSP旅行線路,而且學(xué)習(xí)輸出正確的表示,即各個首都的正確3D坐標(biāo)。值得注意的是,這僅僅是通過在監(jiān)督訓(xùn)練過程中使用 Hamming 距離損失,以及對網(wǎng)絡(luò)輸出使用 Gurobi 中的 MIP 實(shí)現(xiàn)的。
實(shí)際上,已經(jīng)證明在求解器損失函數(shù)的某些假設(shè)下,可以通過黑盒組合求解器傳播梯度。這能夠讓傳統(tǒng)有監(jiān)督方法的標(biāo)準(zhǔn)神經(jīng)網(wǎng)絡(luò)架構(gòu)實(shí)現(xiàn)的組合泛化能力。
深度學(xué)習(xí)+組合求解器的學(xué)習(xí)方法能夠在一些需要組合推理的現(xiàn)實(shí)問題上得到廣泛的應(yīng)用。然而問題在于求解器損失的線性這一假設(shè)前提上,在此假設(shè)下我們究竟可以走多遠(yuǎn)?未來工作的重點(diǎn)以及問題在于我們能否學(xué)習(xí)到組合問題的潛在約束,例如 MIP 組合問題。
參考文獻(xiàn)
Vlastelica, Marin, et al. “Differentiation of Blackbox Combinatorial Solvers” arXiv preprint arXiv:1912.02175 (2019). (http://bit.ly/35IowfE)
Rolínek, Michal, et al. “Optimizing Rank-based Metrics with Blackbox Differentiation.” arXiv preprint arXiv:1912.03500 (2019). (http://bit.ly/35EXIMN)
https://towardsdatascience.com/the-fusion-of-deep-learning-and-combinatorics-4d0112a74fa7
雷峰網(wǎng)原創(chuàng)文章,未經(jīng)授權(quán)禁止轉(zhuǎn)載。詳情見轉(zhuǎn)載須知。