2
本文作者: 亞萌 | 2017-02-06 16:37 |
雷鋒網(wǎng)按:本文作者鄭華濱,原載于知乎。雷鋒網(wǎng)已獲轉(zhuǎn)載授權(quán)。
在GAN的相關(guān)研究如火如荼甚至可以說(shuō)是泛濫的今天,一篇新鮮出爐的arXiv論文《Wassertein GAN》卻在Reddit的Machine Learning頻道火了,連Goodfellow都在帖子里和大家熱烈討論,這篇論文究竟有什么了不得的地方呢?
要知道自從2014年Ian Goodfellow提出以來(lái),GAN就存在著訓(xùn)練困難、生成器和判別器的loss無(wú)法指示訓(xùn)練進(jìn)程、生成樣本缺乏多樣性等問(wèn)題。從那時(shí)起,很多論文都在嘗試解決,但是效果不盡人意,比如最有名的一個(gè)改進(jìn)DCGAN依靠的是對(duì)判別器和生成器的架構(gòu)進(jìn)行實(shí)驗(yàn)枚舉,最終找到一組比較好的網(wǎng)絡(luò)架構(gòu)設(shè)置,但是實(shí)際上是治標(biāo)不治本,沒(méi)有徹底解決問(wèn)題。而今天的主角Wasserstein GAN(下面簡(jiǎn)稱(chēng)WGAN)成功地做到了以下爆炸性的幾點(diǎn):
徹底解決GAN訓(xùn)練不穩(wěn)定的問(wèn)題,不再需要小心平衡生成器和判別器的訓(xùn)練程度
基本解決了collapse mode的問(wèn)題,確保了生成樣本的多樣性
訓(xùn)練過(guò)程中終于有一個(gè)像交叉熵、準(zhǔn)確率這樣的數(shù)值來(lái)指示訓(xùn)練的進(jìn)程,這個(gè)數(shù)值越小代表GAN訓(xùn)練得越好,代表生成器產(chǎn)生的圖像質(zhì)量越高(如題圖所示)
以上一切好處不需要精心設(shè)計(jì)的網(wǎng)絡(luò)架構(gòu),最簡(jiǎn)單的多層全連接網(wǎng)絡(luò)就可以做到
那以上好處來(lái)自哪里?這就是令人拍案叫絕的部分了——實(shí)際上作者整整花了兩篇論文,在第一篇《Towards Principled Methods for Training Generative Adversarial Networks》里面推了一堆公式定理,從理論上分析了原始GAN的問(wèn)題所在,從而針對(duì)性地給出了改進(jìn)要點(diǎn);在這第二篇《Wassertein GAN》里面,又再?gòu)倪@個(gè)改進(jìn)點(diǎn)出發(fā)推了一堆公式定理,最終給出了改進(jìn)的算法實(shí)現(xiàn)流程,而改進(jìn)后相比原始GAN的算法實(shí)現(xiàn)流程卻只改了四點(diǎn):
判別器最后一層去掉sigmoid
生成器和判別器的loss不取log
每次更新判別器的參數(shù)之后把它們的絕對(duì)值截?cái)嗟讲怀^(guò)一個(gè)固定常數(shù)c
不要用基于動(dòng)量的優(yōu)化算法(包括momentum和Adam),推薦RMSProp,SGD也行
算法截圖如下:
改動(dòng)是如此簡(jiǎn)單,效果卻驚人地好,以至于Reddit上不少人在感嘆:就這樣?沒(méi)有別的了? 太簡(jiǎn)單了吧!這些反應(yīng)讓我想起了一個(gè)頗有年頭的雞湯段子,說(shuō)是一個(gè)工程師在電機(jī)外殼上用粉筆劃了一條線排除了故障,要價(jià)一萬(wàn)美元——畫(huà)一條線,1美元;知道在哪畫(huà)線,9999美元。上面這四點(diǎn)改進(jìn)就是作者M(jìn)artin Arjovsky劃的簡(jiǎn)簡(jiǎn)單單四條線,對(duì)于工程實(shí)現(xiàn)便已足夠,但是知道在哪劃線,背后卻是精巧的數(shù)學(xué)分析,而這也是本文想要整理的內(nèi)容。
本文內(nèi)容分為五個(gè)部分:
原始GAN究竟出了什么問(wèn)題?(此部分較長(zhǎng))
WGAN之前的一個(gè)過(guò)渡解決方案
Wasserstein距離的優(yōu)越性質(zhì)
從Wasserstein距離到WGAN
總結(jié)
理解原文的很多公式定理需要對(duì)測(cè)度論、 拓?fù)鋵W(xué)等數(shù)學(xué)知識(shí)有所掌握,本文會(huì)從直觀的角度對(duì)每一個(gè)重要公式進(jìn)行解讀,有時(shí)通過(guò)一些低維的例子幫助讀者理解數(shù)學(xué)背后的思想,所以不免會(huì)失于嚴(yán)謹(jǐn),如有引喻不當(dāng)之處,歡迎在評(píng)論中指出。
以下簡(jiǎn)稱(chēng)《Wassertein GAN》為“WGAN本作”,簡(jiǎn)稱(chēng)《Towards Principled Methods for Training Generative Adversarial Networks》為“WGAN前作”。
WGAN源碼實(shí)現(xiàn):martinarjovsky/WassersteinGAN
回顧一下,原始GAN中判別器要最小化如下?lián)p失函數(shù),盡可能把真實(shí)樣本分為正例,生成樣本分為負(fù)例:
(公式1 )
其中是真實(shí)樣本分布,
是由生成器產(chǎn)生的樣本分布。對(duì)于生成器,Goodfellow一開(kāi)始提出來(lái)一個(gè)損失函數(shù),后來(lái)又提出了一個(gè)改進(jìn)的損失函數(shù),分別是
(公式2)
(公式3)
后者在WGAN兩篇論文中稱(chēng)為“the - log D alternative”或“the - log D trick”。WGAN前作分別分析了這兩種形式的原始GAN各自的問(wèn)題所在,下面分別說(shuō)明。
一句話概括:判別器越好,生成器梯度消失越嚴(yán)重。WGAN前作從兩個(gè)角度進(jìn)行了論證,第一個(gè)角度是從生成器的等價(jià)損失函數(shù)切入的。
首先從公式1可以得到,在生成器G固定參數(shù)時(shí)最優(yōu)的判別器D應(yīng)該是什么。對(duì)于一個(gè)具體的樣本,它可能來(lái)自真實(shí)分布也可能來(lái)自生成分布,它對(duì)公式1損失函數(shù)的貢獻(xiàn)是
令其關(guān)于的導(dǎo)數(shù)為0,得
化簡(jiǎn)得最優(yōu)判別器為:
(公式4)
這個(gè)結(jié)果從直觀上很容易理解,就是看一個(gè)樣本來(lái)自真實(shí)分布和生成分布的可能性的相對(duì)比例。如果
且
,最優(yōu)判別器就應(yīng)該非常自信地給出概率0;如果
,說(shuō)明該樣本是真是假的可能性剛好一半一半,此時(shí)最優(yōu)判別器也應(yīng)該給出概率0.5。
然而GAN訓(xùn)練有一個(gè)trick,就是別把判別器訓(xùn)練得太好,否則在實(shí)驗(yàn)中生成器會(huì)完全學(xué)不動(dòng)(loss降不下去),為了探究背后的原因,我們就可以看看在極端情況——判別器最優(yōu)時(shí),生成器的損失函數(shù)變成什么。給公式2加上一個(gè)不依賴(lài)于生成器的項(xiàng),使之變成
注意,最小化這個(gè)損失函數(shù)等價(jià)于最小化公式2,而且它剛好是判別器損失函數(shù)的反。代入最優(yōu)判別器即公式4,再進(jìn)行簡(jiǎn)單的變換可以得到
(公式5)
變換成這個(gè)樣子是為了引入Kullback–Leibler divergence(簡(jiǎn)稱(chēng)KL散度)和Jensen-Shannon divergence(簡(jiǎn)稱(chēng)JS散度)這兩個(gè)重要的相似度衡量指標(biāo),后面的主角之一Wasserstein距離,就是要來(lái)吊打它們兩個(gè)的。所以接下來(lái)介紹這兩個(gè)重要的配角——KL散度和JS散度:
(公式6)
(公式7)
于是公式5就可以繼續(xù)寫(xiě)成
(公式8)
到這里讀者可以先喘一口氣,看看目前得到了什么結(jié)論:根據(jù)原始GAN定義的判別器loss,我們可以得到最優(yōu)判別器的形式;而在最優(yōu)判別器的下,我們可以把原始GAN定義的生成器loss等價(jià)變換為最小化真實(shí)分布與生成分布
之間的JS散度。我們?cè)接?xùn)練判別器,它就越接近最優(yōu),最小化生成器的loss也就會(huì)越近似于最小化
和
之間的JS散度。
問(wèn)題就出在這個(gè)JS散度上。我們會(huì)希望如果兩個(gè)分布之間越接近它們的JS散度越小,我們通過(guò)優(yōu)化JS散度就能將“拉向”
,最終以假亂真。這個(gè)希望在兩個(gè)分布有所重疊的時(shí)候是成立的,但是如果兩個(gè)分布完全沒(méi)有重疊的部分,或者它們重疊的部分可忽略(下面解釋什么叫可忽略),它們的JS散度是多少呢?
答案是,因?yàn)閷?duì)于任意一個(gè)x只有四種可能:
且
且
且
且
第一種對(duì)計(jì)算JS散度無(wú)貢獻(xiàn),第二種情況由于重疊部分可忽略所以貢獻(xiàn)也為0,第三種情況對(duì)公式7右邊第一個(gè)項(xiàng)的貢獻(xiàn)是,第四種情況與之類(lèi)似,所以最終
。
換句話說(shuō),無(wú)論跟
是遠(yuǎn)在天邊,還是近在眼前,只要它們倆沒(méi)有一點(diǎn)重疊或者重疊部分可忽略,JS散度就固定是常數(shù)
,而這對(duì)于梯度下降方法意味著——梯度為0!此時(shí)對(duì)于最優(yōu)判別器來(lái)說(shuō),生成器肯定是得不到一丁點(diǎn)梯度信息的;即使對(duì)于接近最優(yōu)的判別器來(lái)說(shuō),生成器也有很大機(jī)會(huì)面臨梯度消失的問(wèn)題。
但是與
不重疊或重疊部分可忽略的可能性有多大?不嚴(yán)謹(jǐn)?shù)拇鸢甘牵悍浅4?。比較嚴(yán)謹(jǐn)?shù)拇鸢甘牵?strong>當(dāng)
與
的支撐集(support)是高維空間中的低維流形(manifold)時(shí),
與
重疊部分測(cè)度(measure)為0的概率為1。
不用被奇怪的術(shù)語(yǔ)嚇得關(guān)掉頁(yè)面,雖然論文給出的是嚴(yán)格的數(shù)學(xué)表述,但是直觀上其實(shí)很容易理解。首先簡(jiǎn)單介紹一下這幾個(gè)概念:
支撐集(support)其實(shí)就是函數(shù)的非零部分子集,比如ReLU函數(shù)的支撐集就是,一個(gè)概率分布的支撐集就是所有概率密度非零部分的集合。
流形(manifold)是高維空間中曲線、曲面概念的拓廣,我們可以在低維上直觀理解這個(gè)概念,比如我們說(shuō)三維空間中的一個(gè)曲面是一個(gè)二維流形,因?yàn)樗谋举|(zhì)維度(intrinsic dimension)只有2,一個(gè)點(diǎn)在這個(gè)二維流形上移動(dòng)只有兩個(gè)方向的自由度。同理,三維空間或者二維空間中的一條曲線都是一個(gè)一維流形。
測(cè)度(measure)是高維空間中長(zhǎng)度、面積、體積概念的拓廣,可以理解為“超體積”。
回過(guò)頭來(lái)看第一句話,“當(dāng)與
的支撐集是高維空間中的低維流形時(shí)”,基本上是成立的。原因是GAN中的生成器一般是從某個(gè)低維(比如100維)的隨機(jī)分布中采樣出一個(gè)編碼向量,再經(jīng)過(guò)一個(gè)神經(jīng)網(wǎng)絡(luò)生成出一個(gè)高維樣本(比如64x64的圖片就有4096維)。當(dāng)生成器的參數(shù)固定時(shí),生成樣本的概率分布雖然是定義在4096維的空間上,但它本身所有可能產(chǎn)生的變化已經(jīng)被那個(gè)100維的隨機(jī)分布限定了,其本質(zhì)維度就是100,再考慮到神經(jīng)網(wǎng)絡(luò)帶來(lái)的映射降維,最終可能比100還小,所以生成樣本分布的支撐集就在4096維空間中構(gòu)成一個(gè)最多100維的低維流形,“撐不滿(mǎn)”整個(gè)高維空間。
“撐不滿(mǎn)”就會(huì)導(dǎo)致真實(shí)分布與生成分布難以“碰到面”,這很容易在二維空間中理解:一方面,二維平面中隨機(jī)取兩條曲線,它們之間剛好存在重疊線段的概率為0;另一方面,雖然它們很大可能會(huì)存在交叉點(diǎn),但是相比于兩條曲線而言,交叉點(diǎn)比曲線低一個(gè)維度,長(zhǎng)度(測(cè)度)為0,可忽略。三維空間中也是類(lèi)似的,隨機(jī)取兩個(gè)曲面,它們之間最多就是比較有可能存在交叉線,但是交叉線比曲面低一個(gè)維度,面積(測(cè)度)是0,可忽略。從低維空間拓展到高維空間,就有了如下邏輯:因?yàn)橐婚_(kāi)始生成器隨機(jī)初始化,所以幾乎不可能與
有什么關(guān)聯(lián),所以它們的支撐集之間的重疊部分要么不存在,要么就比
和
的最小維度還要低至少一個(gè)維度,故而測(cè)度為0。所謂“重疊部分測(cè)度為0”,就是上文所言“不重疊或者重疊部分可忽略”的意思。
我們就得到了WGAN前作中關(guān)于生成器梯度消失的第一個(gè)論證:在(近似)最優(yōu)判別器下,最小化生成器的loss等價(jià)于最小化與
之間的JS散度,而由于
與
幾乎不可能有不可忽略的重疊,所以無(wú)論它們相距多遠(yuǎn)JS散度都是常數(shù)
,最終導(dǎo)致生成器的梯度(近似)為0,梯度消失。
接著作者寫(xiě)了很多公式定理從第二個(gè)角度進(jìn)行論證,但是背后的思想也可以直觀地解釋?zhuān)?/p>
首先,與
之間幾乎不可能有不可忽略的重疊,所以無(wú)論它們之間的“縫隙”多狹小,都肯定存在一個(gè)最優(yōu)分割曲面把它們隔開(kāi),最多就是在那些可忽略的重疊處隔不開(kāi)而已。
由于判別器作為一個(gè)神經(jīng)網(wǎng)絡(luò)可以無(wú)限擬合這個(gè)分隔曲面,所以存在一個(gè)最優(yōu)判別器,對(duì)幾乎所有真實(shí)樣本給出概率1,對(duì)幾乎所有生成樣本給出概率0,而那些隔不開(kāi)的部分就是難以被最優(yōu)判別器分類(lèi)的樣本,但是它們的測(cè)度為0,可忽略。
最優(yōu)判別器在真實(shí)分布和生成分布的支撐集上給出的概率都是常數(shù)(1和0),導(dǎo)致生成器的loss梯度為0,梯度消失。
有了這些理論分析,原始GAN不穩(wěn)定的原因就徹底清楚了:判別器訓(xùn)練得太好,生成器梯度消失,生成器loss降不下去;判別器訓(xùn)練得不好,生成器梯度不準(zhǔn),四處亂跑。只有判別器訓(xùn)練得不好不壞才行,但是這個(gè)火候又很難把握,甚至在同一輪訓(xùn)練的前后不同階段這個(gè)火候都可能不一樣,所以GAN才那么難訓(xùn)練。
實(shí)驗(yàn)輔證如下:
WGAN前作Figure 2。先分別將DCGAN訓(xùn)練1,20,25個(gè)epoch,然后固定生成器不動(dòng),判別器重新隨機(jī)初始化從頭開(kāi)始訓(xùn)練,對(duì)于第一種形式的生成器loss產(chǎn)生的梯度可以打印出其尺度的變化曲線,可以看到隨著判別器的訓(xùn)練,生成器的梯度均迅速衰減。注意y軸是對(duì)數(shù)坐標(biāo)軸。
一句話概括:最小化第二種生成器loss函數(shù),會(huì)等價(jià)于最小化一個(gè)不合理的距離衡量,導(dǎo)致兩個(gè)問(wèn)題,一是梯度不穩(wěn)定,二是collapse mode即多樣性不足。WGAN前作又是從兩個(gè)角度進(jìn)行了論證,下面只說(shuō)第一個(gè)角度,因?yàn)閷?duì)于第二個(gè)角度我難以找到一個(gè)直觀的解釋方式,感興趣的讀者還是去看論文吧(逃)。
如前文所說(shuō),Ian Goodfellow提出的“- log D trick”是把生成器loss改成
(公式3)
上文推導(dǎo)已經(jīng)得到在最優(yōu)判別器下
(公式9)
我們可以把KL散度(注意下面是先g后r)變換成含的形式:
(公式10)
由公式3,9,10可得最小化目標(biāo)的等價(jià)變形
注意上式最后兩項(xiàng)不依賴(lài)于生成器G,最終得到最小化公式3等價(jià)于最小化
(公式11)
這個(gè)等價(jià)最小化目標(biāo)存在兩個(gè)嚴(yán)重的問(wèn)題。第一是它同時(shí)要最小化生成分布與真實(shí)分布的KL散度,卻又要最大化兩者的JS散度,一個(gè)要拉近,一個(gè)卻要推遠(yuǎn)!這在直觀上非?;闹嚕跀?shù)值上則會(huì)導(dǎo)致梯度不穩(wěn)定,這是后面那個(gè)JS散度項(xiàng)的毛病。
第二,即便是前面那個(gè)正常的KL散度項(xiàng)也有毛病。因?yàn)镵L散度不是一個(gè)對(duì)稱(chēng)的衡量,與
是有差別的。以前者為例
當(dāng)而時(shí)
,
,
對(duì)貢獻(xiàn)趨近0
當(dāng)而
時(shí),
,
對(duì)貢獻(xiàn)趨近正無(wú)窮
換言之,對(duì)于上面兩種錯(cuò)誤的懲罰是不一樣的,第一種錯(cuò)誤對(duì)應(yīng)的是“生成器沒(méi)能生成真實(shí)的樣本”,懲罰微?。坏诙N錯(cuò)誤對(duì)應(yīng)的是“生成器生成了不真實(shí)的樣本” ,懲罰巨大。第一種錯(cuò)誤對(duì)應(yīng)的是缺乏多樣性,第二種錯(cuò)誤對(duì)應(yīng)的是缺乏準(zhǔn)確性。這一放一打之下,生成器寧可多生成一些重復(fù)但是很“安全”的樣本,也不愿意去生成多樣性的樣本,因?yàn)槟菢右徊恍⌒木蜁?huì)產(chǎn)生第二種錯(cuò)誤,得不償失。這種現(xiàn)象就是大家常說(shuō)的collapse mode。
第一部分小結(jié):在原始GAN的(近似)最優(yōu)判別器下,第一種生成器loss面臨梯度消失問(wèn)題,第二種生成器loss面臨優(yōu)化目標(biāo)荒謬、梯度不穩(wěn)定、對(duì)多樣性與準(zhǔn)確性懲罰不平衡導(dǎo)致mode collapse這幾個(gè)問(wèn)題。
實(shí)驗(yàn)輔證如下:
WGAN前作Figure 3。先分別將DCGAN訓(xùn)練1,20,25個(gè)epoch,然后固定生成器不動(dòng),判別器重新隨機(jī)初始化從頭開(kāi)始訓(xùn)練,對(duì)于第二種形式的生成器loss產(chǎn)生的梯度可以打印出其尺度的變化曲線,可以看到隨著判別器的訓(xùn)練,藍(lán)色和綠色曲線中生成器的梯度迅速增長(zhǎng),說(shuō)明梯度不穩(wěn)定,紅線對(duì)應(yīng)的是DCGAN相對(duì)收斂的狀態(tài),梯度才比較穩(wěn)定。
原始GAN問(wèn)題的根源可以歸結(jié)為兩點(diǎn),一是等價(jià)優(yōu)化的距離衡量(KL散度、JS散度)不合理,二是生成器隨機(jī)初始化后的生成分布很難與真實(shí)分布有不可忽略的重疊。
WGAN前作其實(shí)已經(jīng)針對(duì)第二點(diǎn)提出了一個(gè)解決方案,就是對(duì)生成樣本和真實(shí)樣本加噪聲,直觀上說(shuō),使得原本的兩個(gè)低維流形“彌散”到整個(gè)高維空間,強(qiáng)行讓它們產(chǎn)生不可忽略的重疊。而一旦存在重疊,JS散度就能真正發(fā)揮作用,此時(shí)如果兩個(gè)分布越靠近,它們“彌散”出來(lái)的部分重疊得越多,JS散度也會(huì)越小而不會(huì)一直是一個(gè)常數(shù),于是(在第一種原始GAN形式下)梯度消失的問(wèn)題就解決了。在訓(xùn)練過(guò)程中,我們可以對(duì)所加的噪聲進(jìn)行退火(annealing),慢慢減小其方差,到后面兩個(gè)低維流形“本體”都已經(jīng)有重疊時(shí),就算把噪聲完全拿掉,JS散度也能照樣發(fā)揮作用,繼續(xù)產(chǎn)生有意義的梯度把兩個(gè)低維流形拉近,直到它們接近完全重合。以上是對(duì)原文的直觀解釋。
在這個(gè)解決方案下我們可以放心地把判別器訓(xùn)練到接近最優(yōu),不必?fù)?dān)心梯度消失的問(wèn)題。而當(dāng)判別器最優(yōu)時(shí),對(duì)公式9取反可得判別器的最小loss為
其中和
分別是加噪后的真實(shí)分布與生成分布。反過(guò)來(lái)說(shuō),從最優(yōu)判別器的loss可以反推出當(dāng)前兩個(gè)加噪分布的JS散度。兩個(gè)加噪分布的JS散度可以在某種程度上代表兩個(gè)原本分布的距離,也就是說(shuō)可以通過(guò)最優(yōu)判別器的loss反映訓(xùn)練進(jìn)程!……真的有這樣的好事嗎?
并沒(méi)有,因?yàn)榧釉隞S散度的具體數(shù)值受到噪聲的方差影響,隨著噪聲的退火,前后的數(shù)值就沒(méi)法比較了,所以它不能成為和
距離的本質(zhì)性衡量。
因?yàn)楸疚牡闹攸c(diǎn)是WGAN本身,所以WGAN前作的加噪方案簡(jiǎn)單介紹到這里,感興趣的讀者可以閱讀原文了解更多細(xì)節(jié)。加噪方案是針對(duì)原始GAN問(wèn)題的第二點(diǎn)根源提出的,解決了訓(xùn)練不穩(wěn)定的問(wèn)題,不需要小心平衡判別器訓(xùn)練的火候,可以放心地把判別器訓(xùn)練到接近最優(yōu),但是仍然沒(méi)能夠提供一個(gè)衡量訓(xùn)練進(jìn)程的數(shù)值指標(biāo)。但是WGAN本作就從第一點(diǎn)根源出發(fā),用Wasserstein距離代替JS散度,同時(shí)完成了穩(wěn)定訓(xùn)練和進(jìn)程指標(biāo)的問(wèn)題!
作者未對(duì)此方案進(jìn)行實(shí)驗(yàn)驗(yàn)證。
Wasserstein距離又叫Earth-Mover(EM)距離,定義如下:
(公式12)
解釋如下:是
和
組合起來(lái)的所有可能的聯(lián)合分布的集合,反過(guò)來(lái)說(shuō),
中每一個(gè)分布的邊緣分布都是
和
。對(duì)于每一個(gè)可能的聯(lián)合分布
而言,可以從中采樣
得到一個(gè)真實(shí)樣本
和一個(gè)生成樣本
,并算出這對(duì)樣本的距離
,所以可以計(jì)算該聯(lián)合分布下樣本對(duì)距離的期望值。在所有可能的聯(lián)合分布
中能夠?qū)@個(gè)期望值
取到的下界
,就定義為Wasserstein距離。
直觀上可以把理解為在
這個(gè)“路徑規(guī)劃”下把
這堆“沙土”挪到
“位置”所需的“消耗”,而
就是“最優(yōu)路徑規(guī)劃”下的“最小消耗”,所以才叫Earth-Mover(推土機(jī))距離。
Wasserstein距離相比KL散度、JS散度的優(yōu)越性在于,即便兩個(gè)分布沒(méi)有重疊,Wasserstein距離仍然能夠反映它們的遠(yuǎn)近。WGAN本作通過(guò)簡(jiǎn)單的例子展示了這一點(diǎn)??紤]如下二維空間中的兩個(gè)分布和
,
在線段AB上均勻分布,
在線段CD上均勻分布,通過(guò)控制參數(shù)
可以控制著兩個(gè)分布的距離遠(yuǎn)近。
此時(shí)容易得到(讀者可自行驗(yàn)證)
(突變)
(突變)
(平滑)
KL散度和JS散度是突變的,要么最大要么最小,Wasserstein距離卻是平滑的,如果我們要用梯度下降法優(yōu)化這個(gè)參數(shù),前兩者根本提供不了梯度,Wasserstein距離卻可以。類(lèi)似地,在高維空間中如果兩個(gè)分布不重疊或者重疊部分可忽略,則KL和JS既反映不了遠(yuǎn)近,也提供不了梯度,但是Wasserstein卻可以提供有意義的梯度。
既然Wasserstein距離有如此優(yōu)越的性質(zhì),如果我們能夠把它定義為生成器的loss,不就可以產(chǎn)生有意義的梯度來(lái)更新生成器,使得生成分布被拉向真實(shí)分布嗎?
沒(méi)那么簡(jiǎn)單,因?yàn)閃asserstein距離定義(公式12)中的沒(méi)法直接求解,不過(guò)沒(méi)關(guān)系,作者用了一個(gè)已有的定理把它變換為如下形式
(公式13)
證明過(guò)程被作者丟到論文附錄中了,我們也姑且不管,先看看上式究竟說(shuō)了什么。
首先需要介紹一個(gè)概念——Lipschitz連續(xù)。它其實(shí)就是在一個(gè)連續(xù)函數(shù)上面額外施加了一個(gè)限制,要求存在一個(gè)常數(shù)
使得定義域內(nèi)的任意兩個(gè)元素
和
都滿(mǎn)足
此時(shí)稱(chēng)函數(shù)的Lipschitz常數(shù)為
。
簡(jiǎn)單理解,比如說(shuō)的定義域是實(shí)數(shù)集合,那上面的要求就等價(jià)于
的導(dǎo)函數(shù)絕對(duì)值不超過(guò)
。再比如說(shuō)
就不是Lipschitz連續(xù),因?yàn)樗膶?dǎo)函數(shù)沒(méi)有上界。Lipschitz連續(xù)條件限制了一個(gè)連續(xù)函數(shù)的最大局部變動(dòng)幅度。
公式13的意思就是在要求函數(shù)的Lipschitz常數(shù)
不超過(guò)
的條件下,對(duì)所有可能滿(mǎn)足條件
的取到
的上界,然后再除以
。特別地,我們可以用一組參數(shù)
來(lái)定義一系列可能的函數(shù)
,此時(shí)求解公式13可以近似變成求解如下形式
(公式14)
再用上我們搞深度學(xué)習(xí)的人最熟悉的那一套,不就可以把用一個(gè)帶參數(shù)
的神經(jīng)網(wǎng)絡(luò)來(lái)表示嘛!由于神經(jīng)網(wǎng)絡(luò)的擬合能力足夠強(qiáng)大,我們有理由相信,這樣定義出來(lái)的一系列
雖然無(wú)法囊括所有可能,但是也足以高度近似公式13要求的那個(gè)
了。
最后,還不能忘了滿(mǎn)足公式14中這個(gè)限制。我們其實(shí)不關(guān)心具體的K是多少,只要它不是正無(wú)窮就行,因?yàn)樗皇菚?huì)使得梯度變大
倍,并不會(huì)影響梯度的方向。所以作者采取了一個(gè)非常簡(jiǎn)單的做法,就是限制神經(jīng)網(wǎng)絡(luò)
的所有參數(shù)
的不超過(guò)某個(gè)范圍
,比如
,此時(shí)所有偏導(dǎo)數(shù)
也不會(huì)超過(guò)某個(gè)范圍,所以一定存在某個(gè)不知道的常數(shù)
使得
的局部變動(dòng)幅度不會(huì)超過(guò)它,Lipschitz連續(xù)條件得以滿(mǎn)足。具體在算法實(shí)現(xiàn)中,只需要每次更新完
后把它c(diǎn)lip回這個(gè)范圍就可以了。
到此為止,我們可以構(gòu)造一個(gè)含參數(shù)、最后一層不是非線性激活層的判別器網(wǎng)絡(luò)
,在限制
不超過(guò)某個(gè)范圍的條件下,使得
(公式15)
盡可能取到最大,此時(shí)就會(huì)近似真實(shí)分布與生成分布之間的Wasserstein距離(忽略常數(shù)倍數(shù)
)。注意原始GAN的判別器做的是真假二分類(lèi)任務(wù),所以最后一層是sigmoid,但是現(xiàn)在WGAN中的判別器
做的是近似擬合Wasserstein距離,屬于回歸任務(wù),所以要把最后一層的sigmoid拿掉。
接下來(lái)生成器要近似地最小化Wasserstein距離,可以最小化,由于Wasserstein距離的優(yōu)良性質(zhì),我們不需要擔(dān)心生成器梯度消失的問(wèn)題。再考慮到
的第一項(xiàng)與生成器無(wú)關(guān),就得到了WGAN的兩個(gè)loss。
(公式16,WGAN生成器loss函數(shù))
(公式17,WGAN判別器loss函數(shù))
公式15是公式17的反,可以指示訓(xùn)練進(jìn)程,其數(shù)值越小,表示真實(shí)分布與生成分布的Wasserstein距離越小,GAN訓(xùn)練得越好。
WGAN完整的算法流程已經(jīng)貼過(guò)了,為了方便讀者此處再貼一遍:
上文說(shuō)過(guò),WGAN與原始GAN第一種形式相比,只改了四點(diǎn):
判別器最后一層去掉sigmoid
生成器和判別器的loss不取log
每次更新判別器的參數(shù)之后把它們的絕對(duì)值截?cái)嗟讲怀^(guò)一個(gè)固定常數(shù)c
不要用基于動(dòng)量的優(yōu)化算法(包括momentum和Adam),推薦RMSProp,SGD也行
前三點(diǎn)都是從理論分析中得到的,已經(jīng)介紹完畢;第四點(diǎn)卻是作者從實(shí)驗(yàn)中發(fā)現(xiàn)的,屬于trick,相對(duì)比較“玄”。作者發(fā)現(xiàn)如果使用Adam,判別器的loss有時(shí)候會(huì)崩掉,當(dāng)它崩掉時(shí),Adam給出的更新方向與梯度方向夾角的cos值就變成負(fù)數(shù),更新方向與梯度方向南轅北轍,這意味著判別器的loss梯度是不穩(wěn)定的,所以不適合用Adam這類(lèi)基于動(dòng)量的優(yōu)化算法。作者改用RMSProp之后,問(wèn)題就解決了,因?yàn)镽MSProp適合梯度不穩(wěn)定的情況。
對(duì)WGAN作者做了不少實(shí)驗(yàn)驗(yàn)證,本文只提比較重要的兩點(diǎn)。第一,判別器所近似的Wasserstein距離與生成器的生成圖片質(zhì)量高度相關(guān),如下所示(此即題圖):
第二,WGAN如果用類(lèi)似DCGAN架構(gòu),生成圖片的效果與DCGAN差不多:
但是厲害的地方在于WGAN不用DCGAN各種特殊的架構(gòu)設(shè)計(jì)也能做到不錯(cuò)的效果,比如如果大家一起拿掉Batch Normalization的話,DCGAN就崩了:
如果WGAN和原始GAN都使用多層全連接網(wǎng)絡(luò)(MLP),不用CNN,WGAN質(zhì)量會(huì)變差些,但是原始GAN不僅質(zhì)量變得更差,而且還出現(xiàn)了collapse mode,即多樣性不足:
最后補(bǔ)充一點(diǎn)論文沒(méi)提到,但是我個(gè)人覺(jué)得比較微妙的問(wèn)題。判別器所近似的Wasserstein距離能夠用來(lái)指示單次訓(xùn)練中的訓(xùn)練進(jìn)程,這個(gè)沒(méi)錯(cuò);接著作者又說(shuō)它可以用于比較多次訓(xùn)練進(jìn)程,指引調(diào)參,我倒是覺(jué)得需要小心些。比如說(shuō)我下次訓(xùn)練時(shí)改了判別器的層數(shù)、節(jié)點(diǎn)數(shù)等超參,判別器的擬合能力就必然有所波動(dòng),再比如說(shuō)我下次訓(xùn)練時(shí)改了生成器兩次迭代之間,判別器的迭代次數(shù),這兩種常見(jiàn)的變動(dòng)都會(huì)使得Wasserstein距離的擬合誤差就與上次不一樣。那么這個(gè)擬合誤差的變動(dòng)究竟有多大,或者說(shuō)不同的人做實(shí)驗(yàn)時(shí)判別器的擬合能力或迭代次數(shù)相差實(shí)在太大,那它們之間還能不能直接比較上述指標(biāo),我都是存疑的。
WGAN前作分析了Ian Goodfellow提出的原始GAN兩種形式各自的問(wèn)題,第一種形式等價(jià)在最優(yōu)判別器下等價(jià)于最小化生成分布與真實(shí)分布之間的JS散度,由于隨機(jī)生成分布很難與真實(shí)分布有不可忽略的重疊以及JS散度的突變特性,使得生成器面臨梯度消失的問(wèn)題;第二種形式在最優(yōu)判別器下等價(jià)于既要最小化生成分布與真實(shí)分布直接的KL散度,又要最大化其JS散度,相互矛盾,導(dǎo)致梯度不穩(wěn)定,而且KL散度的不對(duì)稱(chēng)性使得生成器寧可喪失多樣性也不愿喪失準(zhǔn)確性,導(dǎo)致collapse mode現(xiàn)象。
WGAN前作針對(duì)分布重疊問(wèn)題提出了一個(gè)過(guò)渡解決方案,通過(guò)對(duì)生成樣本和真實(shí)樣本加噪聲使得兩個(gè)分布產(chǎn)生重疊,理論上可以解決訓(xùn)練不穩(wěn)定的問(wèn)題,可以放心訓(xùn)練判別器到接近最優(yōu),但是未能提供一個(gè)指示訓(xùn)練進(jìn)程的可靠指標(biāo),也未做實(shí)驗(yàn)驗(yàn)證。
WGAN本作引入了Wasserstein距離,由于它相對(duì)KL散度與JS散度具有優(yōu)越的平滑特性,理論上可以解決梯度消失問(wèn)題。接著通過(guò)數(shù)學(xué)變換將Wasserstein距離寫(xiě)成可求解的形式,利用一個(gè)參數(shù)數(shù)值范圍受限的判別器神經(jīng)網(wǎng)絡(luò)來(lái)最大化這個(gè)形式,就可以近似Wasserstein距離。在此近似最優(yōu)判別器下優(yōu)化生成器使得Wasserstein距離縮小,就能有效拉近生成分布與真實(shí)分布。WGAN既解決了訓(xùn)練不穩(wěn)定的問(wèn)題,也提供了一個(gè)可靠的訓(xùn)練進(jìn)程指標(biāo),而且該指標(biāo)確實(shí)與生成樣本的質(zhì)量高度相關(guān)。作者對(duì)WGAN進(jìn)行了實(shí)驗(yàn)驗(yàn)證。
雷鋒網(wǎng)
相關(guān)文章:
縱覽深度學(xué)習(xí)技術(shù)前沿,Yoshua Bengio為你解讀如何創(chuàng)造人類(lèi)水平的AI(附PPT)
Google首席科學(xué)家Vincent Vanhoucke:機(jī)器人和深度學(xué)習(xí)正在發(fā)生一些“有趣的融合”| AAAI 2017
雷峰網(wǎng)版權(quán)文章,未經(jīng)授權(quán)禁止轉(zhuǎn)載。詳情見(jiàn)轉(zhuǎn)載須知。