0
本文作者: 奕欣 | 2017-04-24 15:17 |
雷鋒網(wǎng)按:本文作者為中山大學(xué)鄭華濱,他在知乎的提問《生成式對(duì)抗網(wǎng)絡(luò)GAN有哪些最新的發(fā)展,可以實(shí)際應(yīng)用到哪些場(chǎng)景中?》中做了回答,介紹了Wasserstein GAN的最新進(jìn)展。本文為鄭華濱基于此回答向雷鋒網(wǎng)供稿,未經(jīng)許可不得轉(zhuǎn)載。
前段時(shí)間,Wasserstein GAN以其精巧的理論分析、簡(jiǎn)單至極的算法實(shí)現(xiàn)、出色的實(shí)驗(yàn)效果,在GAN研究圈內(nèi)掀起了一陣熱潮(對(duì)WGAN不熟悉的讀者,可以參考我之前寫的介紹文章:令人拍案叫絕的Wasserstein GAN - 知乎專欄)。但是很多人(包括我們實(shí)驗(yàn)室的同學(xué))到了上手跑實(shí)驗(yàn)的時(shí)候,卻發(fā)現(xiàn)WGAN實(shí)際上沒那么完美,反而存在著訓(xùn)練困難、收斂速度慢等問題。其實(shí),WGAN的作者M(jìn)artin Arjovsky不久后就在reddit上表示他也意識(shí)到了這個(gè)問題,認(rèn)為關(guān)鍵在于原設(shè)計(jì)中Lipschitz限制的施加方式不對(duì),并在新論文中提出了相應(yīng)的改進(jìn)方案:
論文:Improved Training of Wasserstein GANs
Tensorflow實(shí)現(xiàn):igul222/improved_wgan_training
首先回顧一下WGAN的關(guān)鍵部分——Lipschitz限制是什么。WGAN中,判別器D和生成器G的loss函數(shù)分別是:
(公式1)
(公式2)
公式1表示判別器希望盡可能拉高真樣本的分?jǐn)?shù),拉低假樣本的分?jǐn)?shù),公式2表示生成器希望盡可能拉高假樣本的分?jǐn)?shù)。
Lipschitz限制則體現(xiàn)為,在整個(gè)樣本空間上,要求判別器函數(shù)
梯度的Lp-norm不大于一個(gè)有限的常數(shù)
:
(公式3)
直觀上解釋,就是當(dāng)輸入的樣本稍微變化后,判別器給出的分?jǐn)?shù)不能發(fā)生太過劇烈的變化。在原來的論文中,這個(gè)限制具體是通過weight clipping的方式實(shí)現(xiàn)的:每當(dāng)更新完一次判別器的參數(shù)之后,就檢查判別器的所有參數(shù)的絕對(duì)值有沒有超過一個(gè)閾值,比如0.01,有的話就把這些參數(shù)clip回 [-0.01, 0.01] 范圍內(nèi)。通過在訓(xùn)練過程中保證判別器的所有參數(shù)有界,就保證了判別器不能對(duì)兩個(gè)略微不同的樣本給出天差地別的分?jǐn)?shù)值,從而間接實(shí)現(xiàn)了Lipschitz限制。
然而weight clipping的實(shí)現(xiàn)方式存在兩個(gè)嚴(yán)重問題:
第一,如公式1所言,判別器loss希望盡可能拉大真假樣本的分?jǐn)?shù)差,然而weight clipping獨(dú)立地限制每一個(gè)網(wǎng)絡(luò)參數(shù)的取值范圍,在這種情況下我們可以想象,最優(yōu)的策略就是盡可能讓所有參數(shù)走極端,要么取最大值(如0.001)要么取最小值(如-0.001)!為了驗(yàn)證這一點(diǎn),作者統(tǒng)計(jì)了經(jīng)過充分訓(xùn)練的判別器中所有網(wǎng)絡(luò)參數(shù)的數(shù)值分布,發(fā)現(xiàn)真的集中在最大和最小兩個(gè)極端上:
這樣帶來的結(jié)果就是,判別器會(huì)非常傾向于學(xué)習(xí)一個(gè)簡(jiǎn)單的映射函數(shù)(想想看,幾乎所有參數(shù)都是正負(fù)0.01,都已經(jīng)可以直接視為一個(gè)二值化神經(jīng)網(wǎng)絡(luò)了,太簡(jiǎn)單了)。而作為一個(gè)深層神經(jīng)網(wǎng)絡(luò)來說,這實(shí)在是對(duì)自身強(qiáng)大擬合能力的巨大浪費(fèi)!判別器沒能充分利用自身的模型能力,經(jīng)過它回傳給生成器的梯度也會(huì)跟著變差。
在正式介紹gradient penalty之前,我們可以先看看在它的指導(dǎo)下,同樣充分訓(xùn)練判別器之后,參數(shù)的數(shù)值分布就合理得多了,判別器也能夠充分利用自身模型的擬合能力:
第二個(gè)問題,weight clipping會(huì)導(dǎo)致很容易一不小心就梯度消失或者梯度爆炸。原因是判別器是一個(gè)多層網(wǎng)絡(luò),如果我們把clipping threshold設(shè)得稍微小了一點(diǎn),每經(jīng)過一層網(wǎng)絡(luò),梯度就變小一點(diǎn)點(diǎn),多層之后就會(huì)指數(shù)衰減;反之,如果設(shè)得稍微大了一點(diǎn),每經(jīng)過一層網(wǎng)絡(luò),梯度變大一點(diǎn)點(diǎn),多層之后就會(huì)指數(shù)爆炸。只有設(shè)得不大不小,才能讓生成器獲得恰到好處的回傳梯度,然而在實(shí)際應(yīng)用中這個(gè)平衡區(qū)域可能很狹窄,就會(huì)給調(diào)參工作帶來麻煩。相比之下,gradient penalty就可以讓梯度在后向傳播的過程中保持平穩(wěn)。論文通過下圖體現(xiàn)了這一點(diǎn),其中橫軸代表判別器從低到高第幾層,縱軸代表梯度回傳到這一層之后的尺度大?。ㄗ⒁饪v軸是對(duì)數(shù)刻度),c是clipping threshold:
說了這么多,gradient penalty到底是什么?
前面提到,Lipschitz限制是要求判別器的梯度不超過K,那我們何不直接設(shè)置一個(gè)額外的loss項(xiàng)來體現(xiàn)這一點(diǎn)呢?比如說:
(公式4)
不過,既然判別器希望盡可能拉大真假樣本的分?jǐn)?shù)差距,那自然是希望梯度越大越好,變化幅度越大越好,所以判別器在充分訓(xùn)練之后,其梯度norm其實(shí)就會(huì)是在K附近。知道了這一點(diǎn),我們可以把上面的loss改成要求梯度norm離K越近越好,效果是類似的:
(公式5)
究竟是公式4好還是公式5好,我看不出來,可能需要實(shí)驗(yàn)驗(yàn)證,反正論文作者選的是公式5。接著我們簡(jiǎn)單地把K定為1,再跟WGAN原來的判別器loss加權(quán)合并,就得到新的判別器loss:
(公式6)
這就是所謂的gradient penalty了嗎?還沒完。公式6有兩個(gè)問題,首先是loss函數(shù)中存在梯度項(xiàng),那么優(yōu)化這個(gè)loss豈不是要算梯度的梯度?一些讀者可能對(duì)此存在疑惑,不過這屬于實(shí)現(xiàn)上的問題,放到后面說。
其次,3個(gè)loss項(xiàng)都是期望的形式,落到實(shí)現(xiàn)上肯定得變成采樣的形式。前面兩個(gè)期望的采樣我們都熟悉,第一個(gè)期望是從真樣本集里面采,第二個(gè)期望是從生成器的噪聲輸入分布采樣后,再由生成器映射到樣本空間??墒堑谌齻€(gè)分布要求我們?cè)谡麄€(gè)樣本空間上采樣,這完全不科學(xué)!由于所謂的維度災(zāi)難問題,如果要通過采樣的方式在圖片或自然語(yǔ)言這樣的高維樣本空間中估計(jì)期望值,所需樣本量是指數(shù)級(jí)的,實(shí)際上沒法做到。
所以,論文作者就非常機(jī)智地提出,我們其實(shí)沒必要在整個(gè)樣本空間上施加Lipschitz限制,只要重點(diǎn)抓住生成樣本集中區(qū)域、真實(shí)樣本集中區(qū)域以及夾在它們中間的區(qū)域就行了。具體來說,我們先隨機(jī)采一對(duì)真假樣本,還有一個(gè)0-1的隨機(jī)數(shù):
(公式7)
然后在和
的連線上隨機(jī)插值采樣:
(公式8)
把按照上述流程采樣得到的所滿足的分布記為
,就得到最終版本的判別器loss:
(公式9)
這就是新論文所采用的gradient penalty方法,相應(yīng)的新WGAN模型簡(jiǎn)稱為WGAN-GP。我們可以做一個(gè)對(duì)比:
weight clipping是對(duì)樣本空間全局生效,但因?yàn)槭情g接限制判別器的梯度norm,會(huì)導(dǎo)致一不小心就梯度消失或者梯度爆炸;
gradient penalty只對(duì)真假樣本集中區(qū)域、及其中間的過渡地帶生效,但因?yàn)槭侵苯影雅袆e器的梯度norm限制在1附近,所以梯度可控性非常強(qiáng),容易調(diào)整到合適的尺度大小。
論文還講了一些使用gradient penalty時(shí)需要注意的配套事項(xiàng),這里只提一點(diǎn):由于我們是對(duì)每個(gè)樣本獨(dú)立地施加梯度懲罰,所以判別器的模型架構(gòu)中不能使用Batch Normalization,因?yàn)樗鼤?huì)引入同個(gè)batch中不同樣本的相互依賴關(guān)系。如果需要的話,可以選擇其他normalization方法,如layer normalization、
weight normalization和instance normalization,這些方法就不會(huì)引入樣本之間的依賴。論文推薦的是layer normalization。
實(shí)驗(yàn)表明,gradient penalty能夠顯著提高訓(xùn)練速度,解決了原始WGAN收斂緩慢的問題:
雖然還是比不過DCGAN,但是因?yàn)閃GAN不存在平衡判別器與生成器的問題,所以會(huì)比DCGAN更穩(wěn)定,還是很有優(yōu)勢(shì)的。不過,作者憑什么能這么說?因?yàn)橄旅娴膶?shí)驗(yàn)體現(xiàn)出,在各種不同的網(wǎng)絡(luò)架構(gòu)下,其他GAN變種能不能訓(xùn)練好是有點(diǎn)看運(yùn)氣的事情,但是WGAN-GP全都能夠訓(xùn)練好,尤其是最下面一行所對(duì)應(yīng)的101層殘差神經(jīng)網(wǎng)絡(luò):
剩下的實(shí)驗(yàn)結(jié)果中,比較厲害的是第一次成功做到了“純粹的”的文本GAN訓(xùn)練!我們知道在圖像上訓(xùn)練GAN是不需要額外的有監(jiān)督信息的,但是之前就沒有人能夠像訓(xùn)練圖像GAN一樣訓(xùn)練好一個(gè)文本GAN,要么依賴于預(yù)訓(xùn)練一個(gè)語(yǔ)言模型,要么就是利用已有的有監(jiān)督ground truth提供指導(dǎo)信息。而現(xiàn)在WGAN-GP終于在無需任何有監(jiān)督信息的情況下,生成出下圖所示的英文字符序列:
它是怎么做到的呢?我認(rèn)為關(guān)鍵之處是對(duì)樣本形式的更改。以前我們一般會(huì)把文本這樣的離散序列樣本表示為sequence of index,但是它把文本表示成sequence of probability vector。對(duì)于生成樣本來說,我們可以取網(wǎng)絡(luò)softmax層輸出的詞典概率分布向量,作為序列中每一個(gè)位置的內(nèi)容;而對(duì)于真實(shí)樣本來說,每個(gè)probability vector實(shí)際上就蛻化為我們熟悉的onehot vector。
但是如果按照傳統(tǒng)GAN的思路來分析,這不是作死嗎?一邊是hard onehot vector,另一邊是soft probability vector,判別器一下子就能夠區(qū)分它們,生成器還怎么學(xué)習(xí)?沒關(guān)系,對(duì)于WGAN來說,真假樣本好不好區(qū)分并不是問題,WGAN只是拉近兩個(gè)分布之間的Wasserstein距離,就算是一邊是hard onehot另一邊是soft probability也可以拉近,在訓(xùn)練過程中,概率向量中的有些項(xiàng)可能會(huì)慢慢變成0.8、0.9到接近1,整個(gè)向量也會(huì)接近onehot,最后我們要真正輸出sequence of index形式的樣本時(shí),只需要對(duì)這些概率向量取argmax得到最大概率的index就行了。
新的樣本表示形式+WGAN的分布拉近能力是一個(gè)“黃金組合”,但除此之外,還有其他因素幫助論文作者跑出上圖的效果,包括:
文本粒度為英文字符,而非英文單詞,所以字典大小才二三十,大大減小了搜索空間
文本長(zhǎng)度也才32
生成器用的不是常見的LSTM架構(gòu),而是多層反卷積網(wǎng)絡(luò),輸入一個(gè)高斯噪聲向量,直接一次性轉(zhuǎn)換出所有32個(gè)字符
上面第三點(diǎn)非常有趣,因?yàn)樗屛衣?lián)想到前段時(shí)間挺火的語(yǔ)言學(xué)科幻電影《降臨》:
里面的外星人“七肢怪”所使用的語(yǔ)言跟人類不同,人類使用的是線性的、串行的語(yǔ)言,而“七肢怪”使用的是非線性的、并行的語(yǔ)言。“七肢怪”在跟主角交流的時(shí)候,都是一次性同時(shí)給出所有的語(yǔ)義單元的,所以說它們其實(shí)是一些多層反卷積網(wǎng)絡(luò)進(jìn)化出來的人工智能生命嗎?
開完腦洞,我們回過頭看,不得不承認(rèn)這個(gè)實(shí)驗(yàn)的setup實(shí)在過于簡(jiǎn)化了,能否擴(kuò)展到更加實(shí)際的復(fù)雜場(chǎng)景,也會(huì)是一個(gè)問題。但是不管怎樣,生成出來的結(jié)果仍然是突破性的。
最后說回gradient penalty的實(shí)現(xiàn)問題。loss中本身包含梯度,優(yōu)化loss就需要求梯度的梯度,這個(gè)功能并不是現(xiàn)在所有深度學(xué)習(xí)框架的標(biāo)配功能,不過好在Tensorflow就有提供這個(gè)接口——tf.gradients。開頭鏈接的GitHub源碼中就是這么寫的:
# interpolates就是隨機(jī)插值采樣得到的圖像 gradients = tf.gradients(Discriminator(interpolates), [ interpolates])[0]
對(duì)于我這樣的PyTorch黨就非常不幸了,高階梯度的功能還在開發(fā),感興趣的PyTorch黨可以訂閱這個(gè)GitHub的pull request:Autograd refactor,如果它被merged了話就可以在最新版中使用高階梯度的功能實(shí)現(xiàn)gradient penalty了。
但是除了等待我們就沒有別的辦法了嗎?其實(shí)可能是有的,我想到了一種近似方法來實(shí)現(xiàn)gradient penalty,只需要把微分換成差分:
(公式10)
也就是說,我們?nèi)匀皇窃诜植?img alt="掀起熱潮的Wasserstein GAN,在近段時(shí)間又有哪些研究進(jìn)展?" src="https://static.leiphone.com/uploads/new/article/740_740/201704/58fd6a9c41e06.png?imageMogr2/quality/90"/>上隨機(jī)采樣,但是一次采兩個(gè),然后要求它們的連線斜率要接近1,這樣理論上也可以起到跟公式9一樣的效果,我自己在MNIST+MLP上簡(jiǎn)單驗(yàn)證過有作用,PyTorch黨甚至Tensorflow黨都可以嘗試用一下。
雷峰網(wǎng)版權(quán)文章,未經(jīng)授權(quán)禁止轉(zhuǎn)載。詳情見轉(zhuǎn)載須知。