0
本文作者: 楊曉凡 | 2019-10-11 19:50 |
雷鋒網(wǎng) AI 科技評(píng)論按:今天我們要介紹對(duì)無(wú)限大的神經(jīng)網(wǎng)絡(luò)的探索。這個(gè)故事來(lái)自 CMU 杜少雷和胡威撰寫的博客《Ultra-Wide Deep Nets and Neural Tangent Kernel (NTK)》,雷鋒網(wǎng) AI 科技評(píng)論編譯。讓我們從頭講起。
機(jī)器學(xué)習(xí)界有一條流傳已久的規(guī)矩:你需要在模型的訓(xùn)練誤差和泛化能力之間謹(jǐn)慎地做出取舍。衡量泛化能力,有一個(gè)很便捷的指標(biāo)是看看模型在訓(xùn)練集和測(cè)試集上的誤差相差多大,那么,一個(gè)較小的模型通常很難在訓(xùn)練集上做到很小的訓(xùn)練誤差,不過(guò)這個(gè)誤差和測(cè)試集上的測(cè)試誤差在同一水平;換用更大的模型以后一般都可以得到更小的訓(xùn)練誤差,但是測(cè)試誤差往往會(huì)比訓(xùn)練誤差大不少;也就是說(shuō),太大、太小的模型都無(wú)法得到較好的測(cè)試誤差。所以大家總結(jié),要尋找模型復(fù)雜度的甜點(diǎn)(「sweet spot」),復(fù)雜度要足夠大,能夠達(dá)到足夠低的訓(xùn)練誤差;復(fù)雜度也不能太大,避免測(cè)試誤差比訓(xùn)練誤差大太多。下面這個(gè)經(jīng)典的 U 型曲線就是根據(jù)這個(gè)理論繪制的 —— 測(cè)試誤差隨著模型復(fù)雜度的增加而先減小、再增大。
不過(guò),隨著深度神經(jīng)網(wǎng)絡(luò)之類的高度復(fù)雜、高度過(guò)參數(shù)化(over-parameterized)的模型得到廣泛研究和使用,大家發(fā)現(xiàn)它們經(jīng)常可以在訓(xùn)練數(shù)據(jù)集上做到接近 0 的誤差,然后還能在測(cè)試數(shù)據(jù)上發(fā)揮出令人驚訝地好的表現(xiàn),如今的網(wǎng)絡(luò)大小也就隨著越來(lái)越大。Belkin 等人(https://arxiv.org/abs/1812.11118 )用一種新的雙峰曲線描述了這個(gè)現(xiàn)象,他們?cè)诮?jīng)典的 U 型曲線的右邊繼續(xù)延伸,描繪出:當(dāng)模型的復(fù)雜度繼續(xù)增大,越過(guò)了「模型復(fù)雜度足以完全擬合訓(xùn)練數(shù)據(jù)」(比如可以用模型為數(shù)據(jù)點(diǎn)取差值)的那個(gè)點(diǎn)之后,測(cè)試誤差就可以持續(xù)下降!有趣的是,越大的模型往往能給出越好的結(jié)果,已經(jīng)跳出了以往的「復(fù)雜度甜點(diǎn)」的考慮范疇。如下圖。
有人懷疑深度學(xué)習(xí)中使用的優(yōu)化算法,比如梯度下降、隨機(jī)梯度下降以及各種變體,其實(shí)起到了隱式地限制模型復(fù)雜度的效果(也就是說(shuō),雖然整個(gè)模型中的參數(shù)很多,但其中真正獨(dú)立有效的參數(shù)只有一部分),也就避免了過(guò)擬合,避免了測(cè)試誤差和訓(xùn)練誤差相差過(guò)大。
另外,「越大的模型往往能給出越好的結(jié)果」,所以很自然地有人會(huì)問(wèn)「如果我們有一個(gè)無(wú)限大的網(wǎng)絡(luò),它的表現(xiàn)會(huì)如何?」按照上面那張雙峰圖,答案就對(duì)應(yīng)著隱藏在圖像的最右側(cè)的東西。上一年中這是一個(gè)熱門的研究問(wèn)題:神經(jīng)網(wǎng)絡(luò)的寬度,也就是卷積層中的通道數(shù)目、或者全連接隱層中的神經(jīng)元數(shù)目,趨近于無(wú)窮大的時(shí)候會(huì)得到怎樣的表現(xiàn)。
乍看上去這個(gè)問(wèn)題是無(wú)解的,要做實(shí)驗(yàn)的話,有再多的計(jì)算資源也無(wú)法訓(xùn)練一個(gè)真正「無(wú)限大」的網(wǎng)絡(luò);而要理論分析的話,有限大小的網(wǎng)絡(luò)都還沒(méi)有研究清楚呢。不過(guò),數(shù)學(xué)和物理領(lǐng)域一直都有研究「趨于無(wú)限大」從而得到新的見(jiàn)解的慣例,研究「趨于無(wú)限大」也在理論上更容易一點(diǎn)。
研究深度神經(jīng)網(wǎng)絡(luò)的學(xué)者們可能還記得無(wú)限寬的神經(jīng)網(wǎng)絡(luò)和核方法之間的聯(lián)系,25 年前 Neal (https://www.cs.toronto.edu/~radford/pin.abstract.html)闡述過(guò),Lee 等(https://openreview.net/forum?id=B1EA-M-0Z)和 Matthews 等(https://arxiv.org/abs/1804.11271)近期也做了拓展。這些核可以對(duì)應(yīng)所有參數(shù)都隨機(jī)選擇、且只有最上層(分類器層)用梯度下降訓(xùn)練過(guò)的的無(wú)限寬的深度神經(jīng)網(wǎng)絡(luò)。具體來(lái)說(shuō),如果我們用 θ 表示網(wǎng)絡(luò)中的參數(shù)集,x 表示網(wǎng)絡(luò)的輸入,就可以把輸出表示為 f(θ,x);接著,W 是 θ 之上的初始化分布(通常是帶有一定縮放的高斯分布),那么對(duì)應(yīng)的核就是 ,其中 x、x' 是兩個(gè)輸入。
那么更常見(jiàn)的「網(wǎng)絡(luò)中所有的層都是訓(xùn)練過(guò)的」這種情況呢?Jacot 等(https://arxiv.org/abs/1806.07572)近期發(fā)現(xiàn)這也和一種核有關(guān)系,他們把它稱為 neural tangent kernel(NTK,神經(jīng)正切核),它的形式可以寫作 。
NTK 和之前提出的核的關(guān)鍵區(qū)別在于,NTK 是由網(wǎng)絡(luò)的輸出相對(duì)于網(wǎng)絡(luò)參數(shù)的梯度之間的內(nèi)乘積來(lái)定義的;其中的梯度來(lái)自訓(xùn)練網(wǎng)絡(luò)時(shí)使用的梯度下降算法。概括地說(shuō),對(duì)于一個(gè)梯度下降訓(xùn)練出的足夠?qū)挼纳疃壬窠?jīng)網(wǎng)絡(luò),下面這個(gè)結(jié)論是成立的:
一個(gè)正確地隨機(jī)初始化的、足夠?qū)挼?/strong>、由具有無(wú)窮小步長(zhǎng)大?。ㄒ簿褪翘荻攘?nbsp;gradient flow)的梯度下降訓(xùn)練的深度神經(jīng)網(wǎng)絡(luò),和一個(gè)帶有 NTK 的確定性核回歸預(yù)測(cè)器是等效的。
這個(gè)結(jié)論在 Jacot 等最初的論文(https://arxiv.org/abs/1806.07572)中就基本確立了,不過(guò)他們要求網(wǎng)絡(luò)的各個(gè)層依次趨近于無(wú)限大。在 Sanjeev Arora, 杜少雷, 胡威, Zhiyuan Li, Ruslan Salakhutdinov and Ruosong Wang 等人最新的論文(https://arxiv.org/abs/1904.11955)中,他們把這個(gè)結(jié)果做了進(jìn)一步的改進(jìn),讓它對(duì)非對(duì)稱環(huán)境也適用,也就是每層的寬度不用依次變大,只需要都高過(guò)某個(gè)有限的閾值就可以。
杜少雷和胡威
詳細(xì)的推導(dǎo)過(guò)程在論文(https://arxiv.org/abs/1904.11955)中有介紹,這里我們只簡(jiǎn)單提一下。作者們?cè)跇?biāo)準(zhǔn)的有監(jiān)督學(xué)習(xí)環(huán)境下考慮這個(gè)問(wèn)題,通過(guò)最小化訓(xùn)練數(shù)據(jù)上的二次方損失的方式訓(xùn)練神經(jīng)網(wǎng)絡(luò)。經(jīng)過(guò)一系列推導(dǎo),作者們得到了含有網(wǎng)絡(luò)梯度項(xiàng)的核矩陣的表達(dá)式。
不過(guò)到這里為止作者們還沒(méi)有使用「網(wǎng)絡(luò)非常寬」的這個(gè)條件。當(dāng)網(wǎng)絡(luò)足夠?qū)挄r(shí),他們推導(dǎo)的核可以逼近某個(gè)確定性的固定核,也就是前面提到的 neural tangent kernel(NTK,神經(jīng)正切核)。不過(guò),確定「到底多寬才是足夠?qū)挕剐枰恍┘僭O(shè)和技巧,在這篇論文中作者們最終得到的是只要網(wǎng)絡(luò)的每一層的寬度各自大于某個(gè)閾值就可以,要比更早的結(jié)果中要求每一層寬度逐漸更趨近于無(wú)窮大的限制更弱一些。
最終作者們推導(dǎo)出訓(xùn)練后的無(wú)限寬神經(jīng)網(wǎng)絡(luò)和 NTK 是等效的。詳細(xì)的推導(dǎo)過(guò)程請(qǐng)見(jiàn)論文原文。
在證明了無(wú)限寬的神經(jīng)網(wǎng)絡(luò)和 NTK 等效之后,我們就有機(jī)會(huì)實(shí)際看看無(wú)限寬的神經(jīng)網(wǎng)絡(luò)的表現(xiàn) —— 只要測(cè)試對(duì)應(yīng)的使用 NTK 的核回歸預(yù)測(cè)器就可以了!作者們?cè)跇?biāo)準(zhǔn)的圖像分類測(cè)試集 CIFAR-10 上進(jìn)行了測(cè)試。由于這是基于圖像的任務(wù),想要得到好的結(jié)果一定少不了卷積結(jié)構(gòu)的參與,所以作者們也推導(dǎo)了卷積 NTK,并和標(biāo)準(zhǔn)的卷積網(wǎng)絡(luò)進(jìn)行對(duì)比。分類準(zhǔn)確率對(duì)比如下:
圖中 CNN-V 是不帶有池化的、正常寬度的 CNN,CNTK-V 是對(duì)應(yīng)的卷積 NTK。作者們也測(cè)試了帶有全局平均池化(GAP)的網(wǎng)絡(luò),也就是 CNN-GAP 和 CNTK-GAP。在所有實(shí)驗(yàn)中都沒(méi)有使用批量標(biāo)準(zhǔn)化(batch normalization)、數(shù)據(jù)增強(qiáng)等等訓(xùn)練技巧,只使用 SGD 訓(xùn)練 CNN,以及 CNTK 使用核回歸的解析方程。
實(shí)驗(yàn)表明 CNTK 其實(shí)是很強(qiáng)的核方法。實(shí)驗(yàn)中最強(qiáng)的是帶有全局平均池化的、11 層的 CNTK,得到了 77.43% 的分類準(zhǔn)確率。目前為止最強(qiáng)的完全基于核的方法來(lái)自 Novak 等(https://openreview.net/forum?id=B1g30j0qF7),而 CNTK 要比他們的準(zhǔn)確率高出超過(guò) 10%。而且 CNTK 和正常 CNN 的表現(xiàn)都很接近,也就是說(shuō)在 CIFAR-10 上超寬(無(wú)限寬)的 CNN 是可以取得不錯(cuò)的表現(xiàn)的。
另外有趣的是,全局池化不僅(如預(yù)期地)顯著提升了正常 CNN 的準(zhǔn)確率,也同樣明顯提升了 CNTK 的準(zhǔn)確率。也許提高神經(jīng)網(wǎng)絡(luò)表現(xiàn)的許多技巧要比我們目前認(rèn)識(shí)到的更通用一些,它們可能也對(duì)核方法有效。
想要理解為什么過(guò)度參數(shù)化的深度神經(jīng)網(wǎng)絡(luò)還能有好得驚人的表現(xiàn)的確是一個(gè)很有挑戰(zhàn)的理論問(wèn)題。不過(guò)現(xiàn)在起碼我們已經(jīng)對(duì)一類非常寬的神經(jīng)網(wǎng)絡(luò)有了更多的了解:可以用 NTK 來(lái)表示它們。不過(guò)還有一個(gè)未解的困難是,關(guān)于核方法的經(jīng)典泛化理論沒(méi)法給出泛化能力的現(xiàn)實(shí)上下界。好在我們至少知道對(duì)核方法的更深的理解也可以幫助我們理解神經(jīng)網(wǎng)絡(luò)了。
另外我們還算探索出了一個(gè)新的方向,那就是把不同的神經(jīng)網(wǎng)絡(luò)架構(gòu)、訓(xùn)練技巧轉(zhuǎn)換到核方法上來(lái),并檢查它們的表現(xiàn)。作者們發(fā)現(xiàn)全局平均池化可以大幅提升核方法的表現(xiàn),那很有可能 BN、drop-out、最大池化之類的方法也能在核方法中發(fā)揮作用;反過(guò)來(lái),我們也可以嘗試把 RNN、圖神經(jīng)網(wǎng)絡(luò)、Transformer 之類的神經(jīng)網(wǎng)絡(luò)轉(zhuǎn)換成核方法。
以及那個(gè)核心的問(wèn)題:有限寬和無(wú)限寬的神經(jīng)網(wǎng)絡(luò)之間確實(shí)有性能區(qū)別,如何解釋這種區(qū)別的原因也是重要的理論研究課題。
原論文:On Exact Computation with an Infinitely Wide Neural Net,無(wú)限寬的神經(jīng)網(wǎng)絡(luò)的精確計(jì)算
論文地址:https://arxiv.org/abs/1904.11955
本文編譯自技術(shù)博客 http://www.offconvex.org/2019/10/03/NTK/,雷鋒網(wǎng) AI 科技評(píng)論編譯
雷峰網(wǎng)版權(quán)文章,未經(jīng)授權(quán)禁止轉(zhuǎn)載。詳情見(jiàn)轉(zhuǎn)載須知。