0
本文作者: 蔣寶尚 | 2020-01-14 17:43 |
人腦顯然是人工智能追求的最高標(biāo)準(zhǔn)。
畢竟人腦使得人類(lèi)擁有了連續(xù)學(xué)習(xí)的能力以及情境依賴(lài)學(xué)習(xí)的能力。
這種可以在新的環(huán)境中不斷吸收新的知識(shí)和根據(jù)不同的環(huán)境靈活調(diào)整自己的行為的能力,也正是深度學(xué)習(xí)系統(tǒng)與人腦相差甚遠(yuǎn)的重要原因。
想讓傳統(tǒng)深度學(xué)習(xí)系統(tǒng)獲得連續(xù)學(xué)習(xí)能力,最重要的是克服人工神經(jīng)網(wǎng)絡(luò)會(huì)出現(xiàn)的“災(zāi)難性遺忘”問(wèn)題,即一旦使用新的數(shù)據(jù)集去訓(xùn)練已有的模型,該模型將會(huì)失去對(duì)原數(shù)據(jù)集識(shí)別的能力。
換句話說(shuō)就是:讓神經(jīng)網(wǎng)絡(luò)在學(xué)習(xí)新知識(shí)的同時(shí)保留舊知識(shí)。
前段時(shí)間,來(lái)自蘇黎世聯(lián)邦理工學(xué)院以及蘇黎世大學(xué)的研究團(tuán)隊(duì)發(fā)表了一篇名為《超網(wǎng)絡(luò)的連續(xù)學(xué)習(xí)》(Continual learning with hypernetworks)的研究。提出了任務(wù)條件化的超網(wǎng)絡(luò)(基于任務(wù)屬性生成目標(biāo)模型權(quán)重的網(wǎng)絡(luò))。該方法能夠有效克服災(zāi)難性的遺忘問(wèn)題。
具體來(lái)說(shuō),該方法能夠幫助在針對(duì)多個(gè)任務(wù)訓(xùn)練網(wǎng)絡(luò)時(shí),有效處理災(zāi)難性的遺忘問(wèn)題。除了在標(biāo)準(zhǔn)持續(xù)學(xué)習(xí)基準(zhǔn)測(cè)試中獲得最先進(jìn)的性能外,長(zhǎng)期的附加實(shí)驗(yàn)任務(wù)序列顯示,任務(wù)條件超網(wǎng)絡(luò)(task-conditioned hypernetworks )表現(xiàn)出非常大的保留先前記憶的能力。
在蘇黎世聯(lián)邦理工學(xué)院以及蘇黎世大學(xué)的這項(xiàng)工作中,最重要的是對(duì)超網(wǎng)絡(luò)(hypernetworks)的應(yīng)用,在介紹超網(wǎng)絡(luò)的連續(xù)學(xué)習(xí)之前,雷鋒網(wǎng) AI科技評(píng)論先對(duì)超網(wǎng)絡(luò)做一下介紹。hyperNetwork是一個(gè)非常有名的網(wǎng)絡(luò),簡(jiǎn)單說(shuō)就是用一個(gè)網(wǎng)絡(luò)來(lái)生成另外一個(gè)網(wǎng)絡(luò)的參數(shù)。
工作原理是:用一個(gè)hypernetwork輸入訓(xùn)練集數(shù)據(jù),然后輸出對(duì)應(yīng)模型的參數(shù),最好的輸出是這些參數(shù)能夠使得在測(cè)試數(shù)據(jù)集上取得好的效果。簡(jiǎn)單來(lái)說(shuō)hypernetwork其實(shí)就是一個(gè)meta network。雷鋒網(wǎng) AI科技評(píng)認(rèn)為傳統(tǒng)的做法是用訓(xùn)練集直接訓(xùn)練這個(gè)模型,但是如果使用hypernetwork則不用訓(xùn)練,拋棄反向傳播與梯度下降,直接輸出參數(shù),這等價(jià)于hypernetwork學(xué)會(huì)了如何學(xué)習(xí)圖像識(shí)別。
論文下載見(jiàn)文末
在《hypernetwork》這篇論文中,作者使用 hyperNetwork 生成 RNN 的權(quán)重,發(fā)現(xiàn)能為 LSTM 生成非共享權(quán)重,并在字符級(jí)語(yǔ)言建模、手寫(xiě)字符生成和神經(jīng)機(jī)器翻譯等序列建模任務(wù)上實(shí)現(xiàn)最先進(jìn)的結(jié)果。超網(wǎng)絡(luò)采用一組包含有關(guān)權(quán)重結(jié)構(gòu)的信息的輸入,并生成該層的權(quán)重,如下圖所示。
超網(wǎng)絡(luò)生成前饋網(wǎng)絡(luò)的權(quán)重:黑色連接和參數(shù)與主網(wǎng)絡(luò)相關(guān)聯(lián),而橙色連接和參數(shù)與超網(wǎng)絡(luò)相關(guān)聯(lián)。
在整個(gè)工作中,雷鋒網(wǎng) AI科技評(píng)發(fā)現(xiàn)作者首先假設(shè)輸入的數(shù)據(jù),......
是可以被儲(chǔ)存的,并能夠使用輸入的數(shù)據(jù)計(jì)算
。另外,可以將未使用的數(shù)據(jù)和已經(jīng)使用過(guò)數(shù)據(jù)進(jìn)行混合來(lái)避免遺忘。假設(shè)F(X,Θ)是模型,那么混合后的數(shù)據(jù)集為{(
,
),。。。,(
,
),(
,
)},其中其中Y?(T)是由模型f(.,
)生成的一組合成目標(biāo)。然而存儲(chǔ)數(shù)據(jù)顯然違背了連續(xù)學(xué)習(xí)的原則,所以在在論文中,作者提出了一種新的元模型fh(
,
)做為解決方案,新的解決方案能夠?qū)㈥P(guān)注點(diǎn)從單個(gè)的數(shù)據(jù)輸入輸出轉(zhuǎn)向參數(shù)集{
},并實(shí)現(xiàn)非儲(chǔ)存的要求。這個(gè)元模型稱(chēng)為任務(wù)條件超網(wǎng)絡(luò),主要思想是建立任務(wù)
和權(quán)重
的映射關(guān)系,能夠降維處理數(shù)據(jù)集的存儲(chǔ),大大節(jié)省內(nèi)存。
在《超網(wǎng)絡(luò)的連續(xù)學(xué)習(xí)》這篇論文中,模型部分主要有3個(gè)部分,第一部分是任務(wù)條件超網(wǎng)絡(luò)。首先,超網(wǎng)絡(luò)會(huì)將目標(biāo)模型參數(shù)化,即不是直接學(xué)習(xí)特定模型的參數(shù),而是學(xué)習(xí)元模型的參數(shù),從而元模型會(huì)輸出超網(wǎng)絡(luò)的權(quán)重,也就是說(shuō)超網(wǎng)絡(luò)只是權(quán)重生成器。
圖a:正則化后的超網(wǎng)絡(luò)生成目標(biāo)網(wǎng)絡(luò)權(quán)重參數(shù);圖b:迭代地使用較小的組塊超網(wǎng)絡(luò)產(chǎn)生目標(biāo)網(wǎng)絡(luò)權(quán)重。
然后利用帶有超網(wǎng)絡(luò)的連續(xù)學(xué)習(xí)輸出正則化。在論文中,作者使用兩步優(yōu)化過(guò)程來(lái)引入記憶保持型超網(wǎng)絡(luò)輸出約束。首先,計(jì)算?Θh(?Θh的計(jì)算原則基于優(yōu)化器的選擇,本文中作者使用Adam),即找到能夠最小化損失函數(shù)的參數(shù)。損失函數(shù)表達(dá)式如下圖所示:
注:Θ? h是模型學(xué)習(xí)之前的超網(wǎng)絡(luò)的參數(shù);?Θh為外生變量;βoutput是用來(lái)控制正則化強(qiáng)度的參數(shù)。
然后考慮模型的,它就像
一樣。在算法的每一個(gè)學(xué)習(xí)步驟中,需要及時(shí)更新,并使損失函數(shù)最小化。在學(xué)習(xí)任務(wù)之后,保存最終
e并將其添加到集合{
}。
模型的第二部分是用分塊的超網(wǎng)絡(luò)進(jìn)行模型壓縮。超網(wǎng)絡(luò)產(chǎn)生目標(biāo)神經(jīng)網(wǎng)絡(luò)的整個(gè)權(quán)重集。然而,超網(wǎng)絡(luò)可以迭代調(diào)用,在每一步只需分塊填充目標(biāo)模型中的一部分。這表明允許應(yīng)用較小的可重復(fù)使用的超網(wǎng)絡(luò)。有趣的是,利用分塊超網(wǎng)絡(luò)可以在壓縮狀態(tài)下解決任務(wù),其中學(xué)習(xí)參數(shù)(超網(wǎng)絡(luò)的那些)的數(shù)量實(shí)際上小于目標(biāo)網(wǎng)絡(luò)參數(shù)的數(shù)量。
為了避免在目標(biāo)網(wǎng)絡(luò)的各個(gè)分區(qū)之間引入權(quán)重共享,作者引入塊嵌入的集合{} 作為超網(wǎng)絡(luò)的附加輸入。因此,目標(biāo)網(wǎng)絡(luò)參數(shù)的全集Θ_trgt=[
,,,
]是通過(guò)在
上迭代而產(chǎn)生的,在這過(guò)程中保持
不變。這樣,超網(wǎng)絡(luò)可以每個(gè)塊上產(chǎn)生截然不同的權(quán)重。另外,為了簡(jiǎn)化訓(xùn)練過(guò)程,作者對(duì)所有任務(wù)使用一組共享的塊嵌入。
模型的第三部分:上下文無(wú)關(guān)推理:未知任務(wù)標(biāo)識(shí)(context-free inference: unknown task identity)。從輸入數(shù)據(jù)的角度確定要解決的任務(wù)。超網(wǎng)絡(luò)需要任務(wù)嵌入輸入來(lái)生成目標(biāo)模型權(quán)重。在某些連續(xù)學(xué)習(xí)的應(yīng)用中,由于任務(wù)標(biāo)識(shí)是明確的,或者可以容易地從上下文線索中推斷,因此可以立即選擇合適的嵌入。在其他情況下,選擇合適的嵌入則不是那么容易。
作者在論文中討論了連續(xù)學(xué)習(xí)中利用任務(wù)條件超網(wǎng)絡(luò)的兩種不同策略。
策略一:依賴(lài)于任務(wù)的預(yù)測(cè)不確定性。神經(jīng)網(wǎng)絡(luò)模型在處理分布外的數(shù)據(jù)方面越來(lái)越可靠。對(duì)于分類(lèi)目標(biāo)分布,理想情況下為不可見(jiàn)數(shù)據(jù)產(chǎn)生平坦的高熵輸出,反之,為分布內(nèi)數(shù)據(jù)產(chǎn)生峰值的低熵響應(yīng)。這提出了第一種簡(jiǎn)單的任務(wù)推理方法(HNET+ENT),即給定任務(wù)標(biāo)識(shí)未知的輸入模式,選擇預(yù)測(cè)不確定性最小的任務(wù)嵌入,并用輸出分布熵量化。
策略二:當(dāng)生成模型可用時(shí),可以通過(guò)將當(dāng)前任務(wù)數(shù)據(jù)與過(guò)去合成的數(shù)據(jù)混合來(lái)規(guī)避災(zāi)難性遺忘。除了保護(hù)生成模型本身,合成數(shù)據(jù)還可以保護(hù)另一模型。這種策略實(shí)際上往往是連續(xù)學(xué)習(xí)中最優(yōu)的解決方案。受這些成功經(jīng)驗(yàn)的啟發(fā),作者探索用回放網(wǎng)絡(luò)(replay network)來(lái)增強(qiáng)深度學(xué)習(xí)系統(tǒng)。
合成回放(Synthetic replay)是一種強(qiáng)大但并不完美的連續(xù)學(xué)習(xí)機(jī)制,因?yàn)樯赡J饺菀灼?,錯(cuò)誤往往會(huì)隨著時(shí)間的推移而積累和放大。作者在一系列關(guān)鍵觀察的基礎(chǔ)上決定:就像目標(biāo)網(wǎng)絡(luò)一樣,重放模型可以由超網(wǎng)絡(luò)指定,并允許使用輸出正則化公式。而不是使用模型自己的回放數(shù)據(jù)。因此,在這種結(jié)合的方法中,合成重放和任務(wù)條件元建模同時(shí)起作用,避免災(zāi)難性遺忘。
作者使用MNIST、CIFAR10和CIFAR-100公共數(shù)據(jù)集對(duì)論文中的方法進(jìn)行了評(píng)估。評(píng)估主要在兩個(gè)方面:(1)研究任務(wù)條件超網(wǎng)絡(luò)在三種連續(xù)學(xué)習(xí)環(huán)境下的記憶保持能力,(2)研究順序?qū)W習(xí)任務(wù)之間的信息傳遞。具體的在評(píng)估實(shí)驗(yàn)中,作者根據(jù)任務(wù)標(biāo)識(shí)是否明確出了三種連續(xù)學(xué)習(xí)場(chǎng)景:CL1,任務(wù)標(biāo)識(shí)明確;CL2,任務(wù)標(biāo)識(shí)不明確,并不需明確推斷;CL3,任務(wù)標(biāo)識(shí)可以明確推斷出來(lái)。另外作者在MNIST數(shù)據(jù)集上構(gòu)建了一個(gè)全連通的網(wǎng)絡(luò),其中超參的設(shè)定參考了van de Ven & Tolias (2019)論文中的方法。在CIFAR實(shí)驗(yàn)中選擇了ResNet-32作為目標(biāo)神經(jīng)網(wǎng)絡(luò)。
van de Ven & Tolias (2019):
Gido M. van de Ven and Andreas S. Tolias. Three scenarios for continual learning. arXiv preprint arXiv:1904.07734, 2019.
為了進(jìn)一步說(shuō)明論文中的方法,作者考慮了四個(gè)連續(xù)學(xué)習(xí)分類(lèi)問(wèn)題中的基準(zhǔn)測(cè)試:非線性回歸,PermutedMNIST,Split-MNIST,Split CIFAR-10/100。
非線性回歸的結(jié)果如下:
注:圖a:有輸出正則化的任務(wù)條件超網(wǎng)絡(luò)可以很容易地對(duì)遞增次數(shù)的多項(xiàng)式序列建模,同時(shí)能夠達(dá)到連續(xù)學(xué)習(xí)的效果。圖b:和多任務(wù)直接訓(xùn)練的目標(biāo)網(wǎng)絡(luò)找到的解決方案類(lèi)似。圖c:循序漸進(jìn)地學(xué)習(xí)會(huì)導(dǎo)致遺忘。
在PermutedMNIST中,作者并對(duì)輸入的圖像數(shù)據(jù)的像素進(jìn)行隨機(jī)排列。發(fā)現(xiàn)在CL1中,任務(wù)條件超網(wǎng)絡(luò)在長(zhǎng)度為T(mén)=10的任務(wù)序列中表現(xiàn)最佳。在PermutedMNIST上任務(wù)條件超網(wǎng)絡(luò)的表現(xiàn)非常好,對(duì)比來(lái)看突觸智能(Synaptic Intelligence) ,online EWC,以及深度生成回放( deep generative replay)方法有差別,具體來(lái)說(shuō)突觸智能和DGR+distill會(huì)發(fā)生退化,online EWC不會(huì)達(dá)到非常高的精度,如下圖a所示。綜合考慮壓縮比率與任務(wù)平均測(cè)試集準(zhǔn)確性,超網(wǎng)絡(luò)允許的壓縮模型,即使目標(biāo)網(wǎng)絡(luò)的參數(shù)數(shù)量超過(guò)超網(wǎng)絡(luò)模型的參數(shù)數(shù)量,精度依然保持恒定,如下圖b所示。
Split-MNIST作為另一個(gè)比較流行的連續(xù)學(xué)習(xí)的基準(zhǔn)測(cè)試,在Split-MNIST中將各個(gè)數(shù)字有序配對(duì),并形成五個(gè)二進(jìn)制分類(lèi)任務(wù),結(jié)果發(fā)現(xiàn)任務(wù)條件超網(wǎng)絡(luò)整體性能表現(xiàn)最好。另外在split MNIST問(wèn)題上任務(wù)重疊,能夠跨任務(wù)傳遞信息,并發(fā)現(xiàn)該算法收斂到可以產(chǎn)生同時(shí)解決舊任務(wù)和新任務(wù)的目標(biāo)模型參數(shù)的超網(wǎng)絡(luò)配置。如下圖所示
圖a:即使在低維度空間下仍然有著高分類(lèi)性能,同時(shí)沒(méi)有發(fā)生遺忘。圖b:即使最后一個(gè)任務(wù)占據(jù)著高性能區(qū)域,并在遠(yuǎn)離嵌入向量的情況下退化情況仍然可接受,其性能仍然較高。
在CIFAR實(shí)驗(yàn)中,作者選擇了ResNet-32作為目標(biāo)神經(jīng)網(wǎng)絡(luò),在實(shí)驗(yàn)過(guò)程中,作者發(fā)現(xiàn)運(yùn)用任務(wù)條件超網(wǎng)絡(luò)基本完全消除了遺忘,另外還會(huì)發(fā)生前向信息反饋,這也就是說(shuō)與從初始條件單獨(dú)學(xué)習(xí)每個(gè)任務(wù)相比,來(lái)自以前任務(wù)的知識(shí)可以讓網(wǎng)絡(luò)表現(xiàn)更好。
綜上,在論文中作者提出了一種新的連續(xù)學(xué)習(xí)的神經(jīng)網(wǎng)絡(luò)應(yīng)用模型--任務(wù)條件超網(wǎng)絡(luò),該方法具有可靈活性和通用性,作為獨(dú)立的連續(xù)學(xué)習(xí)方法可以和生成式回放結(jié)合使用。該方法能夠?qū)崿F(xiàn)較長(zhǎng)的記憶壽命,并能將信息傳輸?shù)轿磥?lái)的任務(wù),能夠滿足連續(xù)學(xué)習(xí)的兩個(gè)基本特性。
參考文獻(xiàn):
HYPERNETWORKS:
https://arxiv.org/pdf/1609.09106.pdf
CONTINUAL LEARNING WITH HYPERNETWORKS
https://arxiv.org/pdf/1906.00695.pdf
https://mp.weixin.qq.com/s/hZcVRraZUe9xA63CaV54Yg
雷峰網(wǎng)原創(chuàng)文章,未經(jīng)授權(quán)禁止轉(zhuǎn)載。詳情見(jiàn)轉(zhuǎn)載須知。