0
本文作者: 我在思考中 | 2021-08-06 11:00 | 專題:ICML 2019 |
本文提出了一個(gè)新的損失函數(shù),混合交叉熵?fù)p失(Mixed CE),用于替代在機(jī)器翻譯的兩種訓(xùn)練方式(Teacher Forcing和 Scheduled Sampling)里常用的交叉熵?fù)p失函數(shù)(CE)。
Mixed CE實(shí)現(xiàn)簡(jiǎn)單,計(jì)算開銷基本和標(biāo)準(zhǔn)的CE持平,并且在多個(gè)翻譯數(shù)據(jù)的多種測(cè)試集上表現(xiàn)優(yōu)于CE。這篇文章我們簡(jiǎn)要介紹Mixed CE的背景和一些主要的實(shí)驗(yàn)結(jié)果。
文章和附錄:http://proceedings.mlr.press/v139/li21n.html
代碼:https://github.com/haorannlp/mix
背景
本節(jié)簡(jiǎn)單介紹一下 Teacher Forcing和 Scheduled Sampling 的背景。
Teacher Forcing[1]訓(xùn)練方式指的是當(dāng)我們?cè)谟?xùn)練一個(gè)自回歸模型時(shí)(比如RNN,LSTM,或者Transformer的decoder部分),我們需要將真實(shí)的目標(biāo)序列(比如我們想要翻譯的句子)作為自回歸模型的輸入,以便模型能夠收斂的更快更好。通常在Teacher Forcing(TF)這種訓(xùn)練方式下,模型使用的損失函數(shù)是CE:
值得注意的是,機(jī)器翻譯(MT)本身是一個(gè)一對(duì)多的映射問題,比如同樣一句中文可以翻譯成不同的英文,而使用CE的時(shí)候,因?yàn)槊總€(gè)單詞使用一個(gè)one-hot encoding去表示的,這種情況下MT是被我們當(dāng)作了一個(gè)一對(duì)一的映射問題。這種方式可能會(huì)限制模型的泛化能力,因?yàn)槭褂肅E的模型學(xué)到的條件分布 更接近于一個(gè)one-hot encoding,而非數(shù)據(jù)真實(shí)的條件分布
。但不可否認(rèn)的是,即使模型用CE訓(xùn)練,它在實(shí)踐中也取得了很好的效果。CE在實(shí)踐中的成功意味著模型學(xué)習(xí)到的條件分布
可能也包含著部分真實(shí)分布
的信息。我們能不能在訓(xùn)練的時(shí)候從 提取
的信息呢?這就是我們的Mixed CE所要完成的目標(biāo)。
雖然TF訓(xùn)練方式簡(jiǎn)單,但它會(huì)導(dǎo)致exposure bias的問題,即在訓(xùn)練階段模型使用的輸入來自于真實(shí)數(shù)據(jù)分布,而在測(cè)試階段模型每一時(shí)刻使用的輸入來自于模型上一時(shí)刻的預(yù)測(cè)結(jié)果,這兩個(gè)輸入分布之間的差異被稱作exposure bias。
因此,研究者們進(jìn)而提出了Scheduled Sampling[2](SS)。在自回歸模型每一時(shí)刻的輸入不再是來自于真實(shí)數(shù)據(jù),而是隨機(jī)從真實(shí)數(shù)據(jù)或模型上一時(shí)刻的輸出中采樣一個(gè)點(diǎn)作為輸入。這種方法的本質(zhì)是希望通過在訓(xùn)練階段混入模型自身的預(yù)測(cè)結(jié)果作為輸入,減小其與測(cè)試階段輸入數(shù)據(jù)分布的差異。也就是說,SS所做的是讓訓(xùn)練輸入數(shù)據(jù)分布近似測(cè)試輸入數(shù)據(jù)的分布,從而減輕exposure bias。
而另一種減輕exposure bias的思想是,即使訓(xùn)練和測(cè)試階段輸入來自不同的分布,只要模型的輸出是相似的,這種輸入的差異性也就無關(guān)緊要了。我們的Mixed CE就是想要達(dá)到這樣的目標(biāo)。
需要注意的一點(diǎn)是,SS本來是用于RNN的,但由于Transformer的興起,后續(xù)的研究者們提出了一些改進(jìn)的SS以便適用于Transformer decoder在訓(xùn)練階段能夠并行計(jì)算的特性。即運(yùn)行Transformer deocder兩次,第一次輸入真實(shí)的數(shù)據(jù),然后從t時(shí)刻的輸出分布里采樣一個(gè)數(shù)據(jù)點(diǎn)
, 最終得到一個(gè)序列
。接著,將
和目標(biāo)序列
里面的元素隨機(jī)進(jìn)行混合,得到新序列
。然后把
作為decoder的輸入,按照正常方式進(jìn)行訓(xùn)練。
方法
我們提出的Mixed CE可以同時(shí)用于TF和SS兩種訓(xùn)練方式中。
在TF中,為了應(yīng)用MixedCE,我們首先做出一個(gè)假設(shè):如果模型當(dāng)前預(yù)測(cè)的概率最大的token和目標(biāo)token不一致,那我們認(rèn)為預(yù)測(cè)的token很有可能是目標(biāo)token的同義詞或者同義詞的一部分。
我們做出這個(gè)假設(shè)是因?yàn)樵趯?shí)際中的平行語料庫里,同樣一個(gè)源語言的單詞在目標(biāo)語言會(huì)有多種不同的翻譯方式。如果這些不同的翻譯在語料庫里出現(xiàn)的頻率相差不多,那么在預(yù)測(cè)該源語言單詞時(shí),模型非常有可能給這些不同的翻譯相似的概率,而概率最大的那種翻譯方式恰好是目標(biāo)token的同義詞。
具體來說,Mixed CE的公式如下:
這里的是模型在當(dāng)前時(shí)刻模型預(yù)測(cè)的最有可能的結(jié)果,而根據(jù)我們之前的假設(shè),有可能是的同義詞。Mixed CE通過以 作為目標(biāo)進(jìn)行優(yōu)化,有效利用了
中含有的真實(shí)分布
的信息。同時(shí),這里的
,
是當(dāng)前訓(xùn)練的iteration,total_iter代表了總的訓(xùn)練輪數(shù)。隨著訓(xùn)練的進(jìn)行,模型的效果越來越好,
會(huì)不斷增大,Mixed CE中第二項(xiàng)的權(quán)重也就越大。
在SS中,Mixed CE的形式類似于上述公式:
這里的 是對(duì)第一次運(yùn)行Transformer decoder的輸出進(jìn)行g(shù)reedy采樣的結(jié)果。第一次運(yùn)行Transformer decoder時(shí)的輸入是真實(shí)的目標(biāo)序列,而第二次運(yùn)行時(shí)的輸入是序列
。通過優(yōu)化這個(gè)目標(biāo)函數(shù)的第二部分,無論模型輸入是
還是
,模型總是能夠輸出相似的結(jié)果,也就是說,模型能夠忽略輸入分布的差異,從而減輕了exposure bias的問題。
值得注意的是,相比于CE,Mixed CE在訓(xùn)練期間只增加很少的計(jì)算量,額外的計(jì)算量來自于尋找模型預(yù)測(cè)結(jié)果的最大值。
實(shí)驗(yàn)
由于篇幅有限,我們只列出幾個(gè)重要的實(shí)驗(yàn)結(jié)果,更詳細(xì)的實(shí)驗(yàn)結(jié)果可以在原文中找到。
在TF訓(xùn)練方式中,我們?cè)赪MT’14 En-De上的multi-reference test set上面進(jìn)行了測(cè)試。在這個(gè)測(cè)試集中,每個(gè)源語言的句子有10種不同的reference translation,我們利用beam search為每一句源語言句子生成10個(gè)candidate translations,并且計(jì)算了每一個(gè)Hypothesis相對(duì)于每一種reference translation的BLEU分?jǐn)?shù),并且取它們的平均值或者最大值。結(jié)果如下:
我們可以看到Mixed CE在所有reference上面始終優(yōu)于標(biāo)準(zhǔn)CE。
另外,我們也在一個(gè)paraphrased reference set(WMT’19 En-De)上面進(jìn)行了測(cè)試。這個(gè)測(cè)試集里面的每一個(gè)reference都是經(jīng)過語言專家的改寫,改寫后的句子結(jié)構(gòu)和詞匯的使用都變得更復(fù)雜。結(jié)果如下:
Mixed CE仍然優(yōu)于CE。通常在這個(gè)測(cè)試集上,0.3~0.4 BLEU的提升就表明效果就很顯著了。
由于Mixed CE的形式類似于label smoothing,所以我們也具體比較了Mixed CE和label smoothing。我們利用Pairwise-BLEU(PB)衡量模型輸出分布的平滑程度,PB越大,輸出分布越陡峭,反之則越平滑。結(jié)果如下:
可以看到,加入label smoothing之后,輸出分布變得更加平滑,而Mixed CE使得輸出分布變得更加陡峭。所以Mixed CE和label smoothing是不同的。并且從BLEU的分?jǐn)?shù)可以看出, label smoothing和Mixed CE并不是一個(gè)互斥的關(guān)系,兩者共用效果會(huì)更好。
在SS中,我們以SS和word oracle(SS的一個(gè)變種)作為Baseline。結(jié)果如下:
可以看到Mixed CE總是好于CE。此外,我們?cè)谡撐闹羞€提供了ablation study,以確認(rèn)Mixed CE中的第二項(xiàng)對(duì)性能的提升是必不可少的。
此外,我們?cè)诟戒浿幸擦谐隽艘恍╆P(guān)于domain adaptation的初步實(shí)驗(yàn),歡迎大家繼續(xù)探索Mixed CE在其他領(lǐng)域的應(yīng)用。
結(jié)論
在本文中我們提出了Mixed CE,用于替換在teacher forcing和scheduled sampling中使用CE損失函數(shù)。實(shí)驗(yàn)表明在teacher forcing里,Mixed CE在multi-reference, paraphrased reference set上面的表現(xiàn)總是優(yōu)于CE。同時(shí),我們也對(duì)比了label smoothing和Mixed CE,發(fā)現(xiàn)它們對(duì)輸出分布的影響是不同的。在scheduled sampling當(dāng)中,Mixed CE能夠更有效的減輕exposure bias的影響。
掃碼加入ICML2021交流群:
若二維碼過期或群內(nèi)滿200人時(shí),添加小助手微信(AIyanxishe3),備注ICML2021拉你進(jìn)群。
雷鋒網(wǎng)雷鋒網(wǎng)雷鋒網(wǎng)
雷峰網(wǎng)特約稿件,未經(jīng)授權(quán)禁止轉(zhuǎn)載。詳情見轉(zhuǎn)載須知。
本專題其他文章