0
由于其對(duì)于原始數(shù)據(jù)潛在概率分布的強(qiáng)大感知能力,GAN 成為了當(dāng)下最熱門的生成模型之一。然而,訓(xùn)練不穩(wěn)定、調(diào)參難度大一直是困擾著 GAN 愛好者的老問題。本文是一份干貨滿滿的 GAN 訓(xùn)練心得,希望對(duì)有志從事該領(lǐng)域研究和工作的讀者有所幫助!
在當(dāng)下的深度學(xué)習(xí)研究領(lǐng)域中,對(duì)抗生成網(wǎng)絡(luò)(GAN)是最熱門的話題之一。在過去的幾個(gè)月里,關(guān)于 GAN 的論文數(shù)量呈井噴式增長(zhǎng)。GAN 已經(jīng)被應(yīng)廣泛應(yīng)用到了各種各樣的問題上,如果你之前對(duì)此并不太了解,可以通過下面的 Github 鏈接看到一些酷炫的 GAN 應(yīng)用:
時(shí)至今日,我已經(jīng)閱讀了大量有關(guān) GAN 的文獻(xiàn),但我還從來沒有自己動(dòng)手實(shí)踐過。因此,在瀏覽了一些對(duì)人有所啟發(fā)的論文和 Github 代碼倉(cāng)庫(kù)后,我決定親自嘗試訓(xùn)練一個(gè)簡(jiǎn)單的 GAN。不出所料,我立刻就遇到了一些問題。
本文的目標(biāo)讀者是從 GAN 入門的熱愛深度學(xué)習(xí)的朋友。除非你走了大運(yùn),否則你自己第一次訓(xùn)練一個(gè) GAN 的過程可能是非常令人沮喪的,而且需要花費(fèi)好幾個(gè)小時(shí)才能做好。當(dāng)然,隨著時(shí)間的推移和經(jīng)驗(yàn)的增長(zhǎng),你可能會(huì)漸漸善于訓(xùn)練 GAN。但是對(duì)于初學(xué)者來說,可能會(huì)犯一些錯(cuò),而且不知道該從哪里開始調(diào)試。在本文中,我想向大家分享我第一次從頭開始訓(xùn)練 GAN 時(shí)的觀察和經(jīng)驗(yàn)教訓(xùn),希望本文可以幫助大家節(jié)省幾個(gè)小時(shí)的調(diào)試時(shí)間。
在過去的一年左右的時(shí)間里,深度學(xué)習(xí)圈子里的每個(gè)人(甚至一些沒有參與過深度學(xué)習(xí)相關(guān)工作的人),都應(yīng)該對(duì) GAN 有所耳聞(除非你住在深山老林里、與世隔絕)。生成對(duì)抗網(wǎng)絡(luò)(GAN)是一種數(shù)據(jù)的生成式模型,主要以深度神經(jīng)網(wǎng)絡(luò)的形式存在。也就是說,給定一組訓(xùn)練數(shù)據(jù),GAN 可以學(xué)會(huì)估計(jì)數(shù)據(jù)的底層概率分布。這一點(diǎn)非常有用,因?yàn)槲覀儸F(xiàn)在可以根據(jù)學(xué)到的概率分布生成原始訓(xùn)練數(shù)據(jù)集中沒有出現(xiàn)過的樣本。如上面的鏈接所示,這催生了一些非常實(shí)用的應(yīng)用程序。
該領(lǐng)域的專家已經(jīng)提供了一些很棒的資源來解釋 GAN 以及它們的工作遠(yuǎn)離,所以本文在這里不會(huì)重復(fù)他們的工作。但是為了保持文章的完整性,在這里對(duì)相關(guān)概念進(jìn)行簡(jiǎn)要的回顧。
GAN 模型概覽
生成對(duì)抗網(wǎng)絡(luò)實(shí)際上是兩個(gè)相互競(jìng)爭(zhēng)的深度網(wǎng)絡(luò)。給定一個(gè)訓(xùn)練集 X(比如說幾千張貓的圖像),生成網(wǎng)絡(luò) G(x) 會(huì)將隨機(jī)向量作為輸入,并試圖生成與訓(xùn)練集中的圖像相類似的新圖像樣本。判別器網(wǎng)絡(luò) D(x) 則是一種二分類器,試圖將訓(xùn)練集 X 中「真實(shí)的」貓的圖像和由生成器生成的「假的」貓圖像區(qū)分開來。如此一來,生成網(wǎng)絡(luò)的職責(zé)就是學(xué)習(xí) X 中的數(shù)據(jù)的分布,這樣它就可以生成看起來真實(shí)的貓圖像,并確保判別器無法區(qū)分來自訓(xùn)練集的貓圖像和來自生成器的貓圖像。判別器則需要通過學(xué)習(xí)跟上生成器不斷進(jìn)化、嘗試通過新的方式生成可以「騙過」判別器的「假的」貓圖像的步伐。
最終,如果一切順利,生成器(或多或少)會(huì)學(xué)到訓(xùn)練數(shù)據(jù)的真實(shí)分布,并變得非常善于生成看起來真實(shí)的貓圖像。而判別器則不能再將訓(xùn)練集中的貓圖像和生成的貓圖像區(qū)分開來。
從這個(gè)意義上說,這兩個(gè)網(wǎng)絡(luò)一直在努力確保對(duì)方不能很好地完成自己的任務(wù)。那么,這究竟是如何起作用的呢?
另一種看待 GAN 的方式是:判別器試圖通過高速生成器真實(shí)的貓圖像看起來是怎樣的,從而引導(dǎo)生成器。最終,生成器研究清楚了問題,開始生成看起來真實(shí)的貓圖像。訓(xùn)練 GAN 的方法類似于博弈論中的極大極小算法,兩個(gè)網(wǎng)絡(luò)試圖達(dá)到同時(shí)考慮二者的納什均衡。更多細(xì)節(jié),請(qǐng)參閱本文底部給出的參考資料。
下面,我們將繼續(xù)分析 GAN 的訓(xùn)練過程。為了簡(jiǎn)單起見,我使用了「Keras+Tensorflow 后端」的組合,在 MNIST 數(shù)據(jù)集上訓(xùn)練了一個(gè) GAN(確切地說是 DC-GAN)。這并不太困難,在對(duì)生成器和判別器網(wǎng)絡(luò)進(jìn)行了一些小的調(diào)整之后,GAN 就可以生成清晰的 MNIST 圖像了。
生成的 MNIST 數(shù)字
如果你覺得 MNIST 中黑白數(shù)字沒那么有趣,那么生成各種物體和人的彩色圖片還很酷炫的。而這樣一來,問題就變得棘手了。在攻克了 MNIST 數(shù)據(jù)集之后,顯然下一步就是生成 CIFAR-10 圖像。經(jīng)過日復(fù)一日的超參數(shù)調(diào)參、改變網(wǎng)絡(luò)架構(gòu)、增添或刪除網(wǎng)絡(luò)層,我終于能夠生成出高質(zhì)量的和 CIFAR-10 類似的圖像。
使用 DC-GAN 生成的青蛙
使用 DC-GAN 生成的汽車
我最初使用了一個(gè)非常深的網(wǎng)絡(luò)(但是大多數(shù)情況下性能并不佳),最后使用的真正有效的網(wǎng)絡(luò)卻十分簡(jiǎn)單。在我開始調(diào)整網(wǎng)絡(luò)和訓(xùn)練過程時(shí),經(jīng)過 15 個(gè) epoch 的訓(xùn)練后生成的圖像從這樣:
變成了這樣:
最終的結(jié)果是:
下面,我基于自己犯過的錯(cuò)誤以及一直以來學(xué)到的東西,總結(jié)出了 7 大規(guī)避 GAN 訓(xùn)練陷阱的法則。所以,如果你是一個(gè) GAN 新兵,在訓(xùn)練中沒有很多成功的經(jīng)驗(yàn),也許看看下面的幾個(gè)方面可能會(huì)有所幫助:
鄭重聲明:下面我只是列舉出了我嘗試過的事情以及得到的結(jié)果。并且,我并不是說已經(jīng)解決了所有訓(xùn)練 GAN 的問題。
更大的卷積和可以覆蓋前一層特征圖中的更多像素,因此可以關(guān)注到更多的信息。在 CIFAR-10 數(shù)據(jù)集上,5*5 的卷積核可以取得很好的效果,而在判別器中使用 3*3 的卷積核會(huì)使判別器損失迅速趨近于 0。對(duì)于生成器來說,我們希望在頂層的卷積層中使用較大的卷積核來保持某種平滑性。而在較底層,我并沒有發(fā)現(xiàn)改變卷積核的大小會(huì)帶來任何關(guān)鍵的影響。
卷積核的數(shù)量的提升會(huì)大幅增加參數(shù)的數(shù)量,但通常我們確實(shí)需要更多的卷積核。我?guī)缀踉谒械木矸e層中都使用了 128 個(gè)卷積核。特別是在生成器中,使用較少的卷積核會(huì)使得最終生成的圖像太模糊。因此,似乎使用更多的卷積核有助于捕獲額外的信息,最終會(huì)提升生成圖像的清晰度。
盡管這一開始似乎有些奇怪,但是對(duì)我來說,改變標(biāo)簽的分配是一個(gè)重要的技巧。
如果你正在使用「真實(shí)圖像=1」、「生成圖像=0」的標(biāo)簽分配方法,將標(biāo)簽反轉(zhuǎn)過來會(huì)對(duì)訓(xùn)練有所幫助。正如我們會(huì)在后文中看到的,這有助于在迭代早期梯度流的傳播,也有助于訓(xùn)練的順利進(jìn)行。
這一點(diǎn)在訓(xùn)練判別器時(shí)極為重要。使用硬標(biāo)簽(非 1 即 0)幾乎會(huì)在早期就摧毀所有的學(xué)習(xí)進(jìn)程,導(dǎo)致判別器的損失迅速趨近于 0。我最終用一個(gè) 0-0.1 之間的隨機(jī)數(shù)來代表「標(biāo)簽 0」(真實(shí)圖像),并使用一個(gè) 0.9-1 之間的隨機(jī)數(shù)來代表 「標(biāo)簽 1」(生成圖像)。在訓(xùn)練生成器時(shí)則不用這樣做。
此外,添加一些帶噪聲的標(biāo)簽是有所幫助的。在我的實(shí)驗(yàn)過程中,我將輸入給判別器的圖像中的 5% 的標(biāo)簽隨機(jī)進(jìn)行了反轉(zhuǎn),即真實(shí)圖像被標(biāo)記為生成圖像、生成圖像被標(biāo)記為真實(shí)圖像。
批量歸一化當(dāng)然對(duì)提升最終的結(jié)果有所幫助。加入批量歸一化可以最終生成明顯更清晰的圖像。但是,如果你錯(cuò)誤地設(shè)置了卷積核的大小和數(shù)量,或者判別器損失迅速趨近于 0,那加入批量歸一化可能也無濟(jì)于事。
在網(wǎng)絡(luò)中加入批量歸一化(BN)層后生成的汽車
為了便于訓(xùn)練 GAN,確保輸入數(shù)據(jù)有類似的特性是很有用的。例如,與其在 CIFAR-10 數(shù)據(jù)集中所有 10 個(gè)類別上訓(xùn)練 GAN,不如選出一個(gè)類別(比如汽車或青蛙),訓(xùn)練 GAN 根據(jù)此類數(shù)據(jù)生成圖像。DCGAN 的另外一些變體可以很好地學(xué)會(huì)根據(jù)若干個(gè)類生成圖像。例如,條件 GAN(CGAN)將類別標(biāo)簽一同作為輸入,以類別標(biāo)簽為先驗(yàn)條件生成圖像。但是,如果你從一個(gè)基礎(chǔ)的 DCGAN 開始學(xué)習(xí)訓(xùn)練 GAN,最好保持模型簡(jiǎn)單。
如果可能的話,請(qǐng)監(jiān)控網(wǎng)絡(luò)中的梯度和損失變化。這可以幫助我們了解訓(xùn)練的進(jìn)展情況。如果訓(xùn)練進(jìn)展不是很順利的話,這甚至可以幫助我們進(jìn)行調(diào)試。
理想情況下,生成器應(yīng)該在訓(xùn)練的早期接受大梯度,因?yàn)樗枰獙W(xué)會(huì)如何生成看起來真實(shí)的數(shù)據(jù)。另一方面,判別器則在訓(xùn)練早期則不應(yīng)該總是接受大梯度,因?yàn)樗梢院苋菀椎貐^(qū)分真實(shí)圖像和生成圖像。當(dāng)生成器訓(xùn)練地足夠好時(shí),判別器就沒有那么容易區(qū)分真實(shí)圖像和生成圖像了。它會(huì)不斷發(fā)生錯(cuò)誤,并得到較大的梯度。
我在 CIFAR-10 中的汽車上訓(xùn)練的幾個(gè)早期版本的 GAN 有許多卷積層和批量歸一化層,并且沒有進(jìn)行標(biāo)簽反轉(zhuǎn)。除了監(jiān)控梯度的變化趨勢(shì),監(jiān)控梯度的大小也很重要。如果生成器中網(wǎng)絡(luò)層的梯度太小,學(xué)習(xí)可能會(huì)很慢或者根本不會(huì)進(jìn)行學(xué)習(xí)。
生成器頂層的梯度(x 軸:minibatch 迭代次數(shù))
生成器底層的梯度(x 軸:minibatch 迭代次數(shù))
判別器頂層的梯度(x 軸:minibatch 迭代次數(shù))
判別器底層的梯度(x 軸:minibatch 迭代次數(shù))
生成器最底層的梯度太小,無法進(jìn)行任何的學(xué)習(xí)。判別器的梯度自始至終都沒有變化,說明判別器并沒有真正學(xué)到任何東西?,F(xiàn)在,讓我們將其與帶有上述所有改進(jìn)方案的 GAN 的梯度進(jìn)行對(duì)比,改進(jìn)后的 GAN 得到了很好的、與真實(shí)圖像看起來類似的圖像:
生成器頂層的梯度(x 軸:minibatch 迭代次數(shù))
生成器底層的梯度(x 軸:minibatch 迭代次數(shù))
判別器頂層的梯度(x 軸:minibatch 迭代次數(shù))
判別器底層的梯度(x 軸:minibatch 迭代次數(shù))
此時(shí)生成器底層的梯度明顯要高于之前版本的 GAN。此外,隨著訓(xùn)練的進(jìn)展,梯度流的變化趨勢(shì)與預(yù)期一樣:生成器在訓(xùn)練早期梯度較大,而一旦生成器被訓(xùn)練得足夠好,判別器的頂層就會(huì)維持高的梯度。
可能是由于我缺乏耐心,我犯了一個(gè)愚蠢的錯(cuò)誤——在進(jìn)行了幾百個(gè) minibatch 的訓(xùn)練后,當(dāng)我看到損失函數(shù)仍然沒有任何明顯的下降,生成的樣本仍然充滿噪聲時(shí),我終止了訓(xùn)練。比起等到訓(xùn)練結(jié)束才意識(shí)到網(wǎng)絡(luò)什么都沒有學(xué)到,重新開始工作、節(jié)省時(shí)間確實(shí)讓人心動(dòng)。GAN 的訓(xùn)練時(shí)間很長(zhǎng),初始的少量的損失值和生成的樣本幾乎不能顯示出任何趨勢(shì)和進(jìn)展。在結(jié)束訓(xùn)練過程并調(diào)整設(shè)置之前,還是很有必要等待一段時(shí)間的。
這條規(guī)則的一個(gè)例外情況是:如果你看到判別器損失迅速趨近于 0。如果發(fā)生了這種情況,幾乎就沒有任何機(jī)會(huì)補(bǔ)救了。最好在對(duì)網(wǎng)絡(luò)或訓(xùn)練過程進(jìn)行調(diào)整后重新開始訓(xùn)練。
最終的 GAN 的架構(gòu)如下所示:
希望本文中的這些建議可以幫助所有人從頭開始訓(xùn)練他們的第一個(gè) DC-GAN。下面,本文將給出一些包含大量關(guān)于 GAN 的信息的學(xué)習(xí)資源:
GAN 論文參考:
「Generative Adversarial Networks」
「Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks」
「Improved Techniques for Training GANs」
其他參考鏈接:
「Training GANs: Better understanding and other improved techniques」
「NIPS 2016 GAN 教程」
「Conditional GAN」
本文最終版 GAN 的 Keras 代碼鏈接如下:
https://github.com/utkd/gans/blob/master/cifar10dcgan.ipynb?source=post_page
via https://medium.com/@utk.is.here/keep-calm-and-train-a-gan-pitfalls-and-tips-on-training-generative-adversarial-networks-edd529764aa9 雷鋒網(wǎng)雷鋒網(wǎng)雷鋒網(wǎng)
雷峰網(wǎng)原創(chuàng)文章,未經(jīng)授權(quán)禁止轉(zhuǎn)載。詳情見轉(zhuǎn)載須知。