0
本文作者: 叢末 | 2019-12-30 10:28 |
作者 | 孫天祥
原文標(biāo)題:稀疏共享:當(dāng)多任務(wù)學(xué)習(xí)遇見彩票假設(shè)
本文介紹了復(fù)旦大學(xué)邱錫鵬團(tuán)隊(duì)在AAAI 2020 上錄用的一篇關(guān)于多任務(wù)學(xué)習(xí)的工作:《Learning Sparse Sharing: Architectures for Mltiple Tasks》,這篇文章提出了一種新的參數(shù)共享機(jī)制:稀疏共享。這種共享機(jī)制能夠同時(shí)解決目前主流的三種共享機(jī)制(硬共享、軟共享、分層共享)的限制問(wèn)題。目前這篇文章已經(jīng)開源。
多任務(wù)學(xué)習(xí)(Multi-Task Learning)是一種聯(lián)合多個(gè)任務(wù)同時(shí)學(xué)習(xí)來(lái)增強(qiáng)模型表示和泛化能力的一種手段,目前大都通過(guò)參數(shù)共享來(lái)實(shí)現(xiàn)多任務(wù)學(xué)習(xí)。因此,很多多任務(wù)學(xué)習(xí)的工作都集中在尋找更好的參數(shù)共享機(jī)制上。已有的工作提出了很多參數(shù)共享策略,其中使用的較多的有硬共享,軟共享,分層共享,另外還有一些比較新穎的值得探索的共享機(jī)制,比如梯度共享,元共享等等。這里簡(jiǎn)要介紹使用較多的三種共享機(jī)制(硬共享、軟共享、分層共享)來(lái)引出本文的動(dòng)機(jī)。硬共享是目前應(yīng)用最為廣泛的共享機(jī)制,它把多個(gè)任務(wù)的數(shù)據(jù)表示嵌入到同一個(gè)語(yǔ)義空間中,再為每個(gè)任務(wù)使用一任務(wù)特定層提取任務(wù)特定表示。硬共享實(shí)現(xiàn)起來(lái)非常簡(jiǎn)單,適合處理有較強(qiáng)相關(guān)性的任務(wù),但遇到弱相關(guān)任務(wù)時(shí)常常表現(xiàn)很差。軟共享為每個(gè)任務(wù)都學(xué)習(xí)一個(gè)網(wǎng)絡(luò),但每個(gè)任務(wù)的網(wǎng)絡(luò)都可以訪問(wèn)其他任務(wù)對(duì)應(yīng)網(wǎng)絡(luò)中的信息,例如表示、梯度等。軟共享機(jī)制非常靈活,不需要對(duì)任務(wù)相關(guān)性做任何假設(shè),但是由于為每個(gè)任務(wù)分配一個(gè)網(wǎng)絡(luò),常常需要增加很多參數(shù)。分層共享是在網(wǎng)絡(luò)的低層做較簡(jiǎn)單的任務(wù),在高層做較困難的任務(wù)。分層共享比硬共享要更靈活,同時(shí)所需的參數(shù)又比軟共享少,但是為多個(gè)任務(wù)設(shè)計(jì)高效的分層結(jié)構(gòu)依賴專家經(jīng)驗(yàn)。本文提出了一種新的參數(shù)共享機(jī)制,稀疏共享(sparse sharing),試圖同時(shí)處理上述三個(gè)限制。
目前常用的參數(shù)共享機(jī)制和本文提出的稀疏共享機(jī)制給定一個(gè)基網(wǎng)絡(luò)和多個(gè)任務(wù)的數(shù)據(jù),稀疏共享可以為每個(gè)任務(wù)從基網(wǎng)絡(luò)中抽取出一個(gè)對(duì)應(yīng)的子網(wǎng)絡(luò)來(lái)處理該任務(wù),這些子網(wǎng)絡(luò)部分重疊,我們的算法可以為強(qiáng)相關(guān)的任務(wù)抽取出相似的子網(wǎng)絡(luò)(具有較高的參數(shù)重疊率),為弱相關(guān)的任務(wù)抽取出為差異較大的子網(wǎng)絡(luò)(具有較低的參數(shù)重疊率)。得到這些子網(wǎng)絡(luò)后,再使用多個(gè)任務(wù)的數(shù)據(jù)聯(lián)合訓(xùn)練。
本文算法分為兩個(gè)階段:(a) 為每個(gè)任務(wù)生成子網(wǎng)絡(luò);(b) 多任務(wù)聯(lián)合訓(xùn)練。1、為每個(gè)任務(wù)生成子網(wǎng)絡(luò)這里生成子網(wǎng)絡(luò)算法使用了獲得ICLR'2019最佳論文獎(jiǎng)的彩票假設(shè)(The Lottery Ticket Hypothesis)中提出的迭代數(shù)量級(jí)剪枝方法。假設(shè)基網(wǎng)絡(luò)參數(shù)為 θε,則任務(wù) t 對(duì)應(yīng)的子網(wǎng)絡(luò)的參數(shù)可以表示為 ,其中
表示元素為 0 或 1 的Mask矩陣。對(duì)每個(gè)任務(wù)獨(dú)立的執(zhí)行迭代剪枝,得到每個(gè)任務(wù)對(duì)應(yīng)的Mask矩陣,也就得到了每個(gè)任務(wù)的子網(wǎng)絡(luò)。
值得注意的是,當(dāng)所有任務(wù)的Mask矩陣 =1 時(shí),稀疏共享等價(jià)于硬共享;考慮兩個(gè)任務(wù),任務(wù)1的Mask矩陣在網(wǎng)絡(luò)的第一層為全 1,第二層為全 0,即
={1,0},任務(wù)2的Mask矩陣為全1,即
=1,則任務(wù)1和任務(wù)2構(gòu)成了分層共享架構(gòu)。因此,硬共享和分層共享都可以視作稀疏共享的特例。
為每個(gè)任務(wù)生成子網(wǎng)絡(luò)
上面的算法為每個(gè)任務(wù)都生成了 Z 個(gè)子網(wǎng)絡(luò),現(xiàn)在需要從中挑選出一個(gè)子網(wǎng)絡(luò)作為最后多任務(wù)訓(xùn)練使用的子網(wǎng)絡(luò)。這里采取了一種簡(jiǎn)單的啟發(fā)式做法,即選擇在驗(yàn)證集上表現(xiàn)最好的子網(wǎng)絡(luò)。
2、多任務(wù)聯(lián)合訓(xùn)練
在得到每個(gè)任務(wù)的子網(wǎng)絡(luò)之后,將其合并也就得到了多任務(wù)稀疏共享結(jié)構(gòu),接著使用多個(gè)任務(wù)的數(shù)據(jù)進(jìn)行聯(lián)合訓(xùn)練:1)隨機(jī)挑選一個(gè)任務(wù) t ;2)為任務(wù) t 隨機(jī)采樣一個(gè)batch數(shù)據(jù);3)將該batch數(shù)據(jù)輸入到任務(wù) t 對(duì)應(yīng)的子網(wǎng)絡(luò)中;4)使用該batch數(shù)據(jù)的梯度更新子網(wǎng)絡(luò)的參數(shù);5)回到 1)。
雖然訓(xùn)練每個(gè)任務(wù)時(shí)都只用到了其對(duì)應(yīng)的子網(wǎng)絡(luò),但子網(wǎng)絡(luò)的一部分參數(shù)可能被多個(gè)任務(wù)同時(shí)共享,因此這部分參數(shù)有機(jī)會(huì)被多個(gè)任務(wù)的訓(xùn)練數(shù)據(jù)更新。這樣,相似的任務(wù)傾向于更新相同的部分參數(shù),使其充分享受多任務(wù)學(xué)習(xí)的收益,同時(shí)差異較大的任務(wù)傾向于更新互相隔離的部分參數(shù),以避免任務(wù)之間互相傷害。
學(xué)習(xí)多任務(wù)稀疏共享架構(gòu)
本文在三個(gè)序列標(biāo)注任務(wù)(POS tagging、NER、Chunking)上進(jìn)行了實(shí)驗(yàn),結(jié)果表明稀疏共享超越了單任務(wù)學(xué)習(xí)、硬共享、軟共享和分層共享的效果,同時(shí)所需參數(shù)量最少。
實(shí)驗(yàn)結(jié)果
值得注意的是,多任務(wù)學(xué)習(xí)并不總能帶來(lái)收益,有時(shí)聯(lián)合學(xué)習(xí)多個(gè)任務(wù)會(huì)對(duì)其中某個(gè)任務(wù)帶來(lái)性能損失,例如上表中陰影部分的數(shù)據(jù)。該現(xiàn)象在遷移學(xué)習(xí)和多任務(wù)學(xué)習(xí)中廣泛存在,常被稱為負(fù)遷移(negative transfer)。然而,在本文的實(shí)驗(yàn)中,稀疏共享并沒(méi)有出現(xiàn)負(fù)遷移現(xiàn)象。為了進(jìn)一步探索稀疏共享在避免負(fù)遷移方面的能力,本文又構(gòu)造了一個(gè)弱相關(guān)多任務(wù)學(xué)習(xí)的場(chǎng)景,該場(chǎng)景包含兩個(gè)任務(wù):
真實(shí)的NER任務(wù);
構(gòu)造的假任務(wù),位置預(yù)測(cè)(position prediction, PP),即讓句子中的每個(gè)單詞預(yù)測(cè)其自身在句中位置。
NER和PP兩個(gè)任務(wù)并無(wú)太大相關(guān)性,結(jié)果表明硬共享框架下同時(shí)學(xué)習(xí)兩個(gè)任務(wù)嚴(yán)重傷害的NER任務(wù)的性能,而稀疏共享則由于參數(shù)隔離避免了負(fù)遷移。
稀疏共享有助于避免負(fù)遷移
另外,本文提供了一種新的衡量任務(wù)相關(guān)性的指標(biāo):參數(shù)重疊率(overlap ratio, OR)。怎么驗(yàn)證OR反映了任務(wù)相關(guān)性呢?本文借助了一個(gè)中間工具:硬共享。硬共享非常適合處理強(qiáng)相關(guān)任務(wù),通常任務(wù)相關(guān)性越弱硬共享效果越差。直覺上,在任務(wù)相關(guān)性越弱的場(chǎng)景下,稀疏共享相比硬共享的提升越多,因此我們可以考察稀疏共享相比硬共享的提升與OR是否正相關(guān)來(lái)驗(yàn)證OR是否可以反映任務(wù)相關(guān)性。為此,把上述三個(gè)任務(wù)兩兩組合得到三個(gè)多任務(wù)學(xué)習(xí)場(chǎng)景,結(jié)果如下:
參數(shù)重疊率反映了任務(wù)相關(guān)性
目前得到稀疏共享架構(gòu)的方法還存在一些問(wèn)題,比如整個(gè)過(guò)程分為兩階段因此相比其他共享模式需要的時(shí)間更久,但這篇文章提出的目的主要是提出并驗(yàn)證稀疏共享模式的可行性,而非具體的架構(gòu)學(xué)習(xí)方法。我們正在,也歡迎其他研究者探索更高效的端到端的稀疏分享架構(gòu)學(xué)習(xí)方法。
雷鋒網(wǎng) AI 科技評(píng)論報(bào)道 雷鋒網(wǎng)雷鋒網(wǎng)
雷峰網(wǎng)原創(chuàng)文章,未經(jīng)授權(quán)禁止轉(zhuǎn)載。詳情見轉(zhuǎn)載須知。