0
GAN 從 2014 年誕生以來(lái)發(fā)展的是相當(dāng)火熱,比較著名的 GAN 的應(yīng)用有 Pix2Pix、CycleGAN 等。本篇文章主要是讓初學(xué)者通過(guò)代碼了解 GAN 的結(jié)構(gòu)和運(yùn)作機(jī)制,對(duì)理論細(xì)節(jié)不做過(guò)多介紹。我們還是采用 MNIST 手寫數(shù)據(jù)集(不得不說(shuō)這個(gè)數(shù)據(jù)集對(duì)于新手來(lái)說(shuō)非常好用)來(lái)作為我們的訓(xùn)練數(shù)據(jù),我們將構(gòu)建一個(gè)簡(jiǎn)單的 GAN 來(lái)進(jìn)行手寫數(shù)字圖像的生成。
GAN 主要包括了兩個(gè)部分,即生成器 generator 與判別器 discriminator。生成器主要用來(lái)學(xué)習(xí)真實(shí)圖像分布從而讓自身生成的圖像更加真實(shí),以騙過(guò)判別器。判別器則需要對(duì)接收的圖片進(jìn)行真假判別。在整個(gè)過(guò)程中,生成器努力地讓生成的圖像更加真實(shí),而判別器則努力地去識(shí)別出圖像的真假,這個(gè)過(guò)程相當(dāng)于一個(gè)二人博弈,隨著時(shí)間的推移,生成器和判別器在不斷地進(jìn)行對(duì)抗,最終兩個(gè)網(wǎng)絡(luò)達(dá)到了一個(gè)動(dòng)態(tài)均衡:生成器生成的圖像接近于真實(shí)圖像分布,而判別器識(shí)別不出真假圖像,對(duì)于給定圖像的預(yù)測(cè)為真的概率基本接近 0.5(相當(dāng)于隨機(jī)猜測(cè)類別)。
對(duì)于 GAN 更加直觀的理解可以用一個(gè)例子來(lái)說(shuō)明:造假幣的團(tuán)伙相當(dāng)于生成器,他們想通過(guò)偽造金錢來(lái)騙過(guò)銀行,使得假幣能夠正常交易,而銀行相當(dāng)于判別器,需要判斷進(jìn)來(lái)的錢是真錢還是假幣。因此假幣團(tuán)伙的目的是要造出銀行識(shí)別不出的假幣而騙過(guò)銀行,銀行則是要想辦法準(zhǔn)確地識(shí)別出假幣。
因此,我們可以將上面的內(nèi)容進(jìn)行一個(gè)總結(jié)。給定真 = 1,假 = 0,那么有:
對(duì)于給定的真實(shí)圖片(real image),判別器要為其打上標(biāo)簽 1;
對(duì)于給定的生成圖片(fake image),判別器要為其打上標(biāo)簽 0;
對(duì)于生成器傳給辨別器的生成圖片,生成器希望辨別器打上標(biāo)簽 1。
有了上面的直觀理解,下面就讓我們來(lái)實(shí)現(xiàn)一個(gè) GAN 來(lái)生成手寫數(shù)據(jù)吧!還有一些細(xì)節(jié)會(huì)在代碼部分進(jìn)行介紹。
TensorFlow 1.0
Python 3
Jupyter Notebook
GitHub 地址:NELSONZHAO/zhihu
建議將代碼 pull 下來(lái),有部分代碼實(shí)現(xiàn)沒(méi)有寫在文章中。
數(shù)據(jù)加載與查看
數(shù)據(jù)我們使用 TensorFlow 中給定的 MNIST 數(shù)據(jù)接口。
在構(gòu)建模型之前,我們首先來(lái)看一下我們需要完成的任務(wù):
Inputs
generator
discriminator
定義參數(shù)
loss & optimizer
訓(xùn)練模型
顯示結(jié)果
輸入 inputs
輸入函數(shù)主要來(lái)定義真實(shí)圖片與生成圖片兩個(gè) tensor。
定義生成器
我們的生成器結(jié)構(gòu)如下:
我們使用了一個(gè)采用 Leaky ReLU 作為激活函數(shù)的隱層,并在輸出層加入 tanh 激活函數(shù)。
下面是生成器的代碼。注意在定義生成器和判別器時(shí),我們要指定變量的 scope,這是因?yàn)?GAN 中實(shí)際上包含生成器與辨別器兩個(gè)網(wǎng)絡(luò),在后面進(jìn)行訓(xùn)練時(shí)是分開(kāi)訓(xùn)練的,因此我們要把 scope 定義好,方便訓(xùn)練時(shí)候指定變量。
在這個(gè)網(wǎng)絡(luò)中,我們使用了一個(gè)隱層,并加入 dropout 防止過(guò)擬合。通過(guò)輸入噪聲圖片,generator 輸出一個(gè)與真實(shí)圖片一樣大小的圖像。
在這里我們的隱層激活函數(shù)采用的是 Leaky ReLU(中文不知道咋翻譯),這個(gè)函數(shù)在 ReLU 函數(shù)基礎(chǔ)上改變了左半邊的定義。
圖片來(lái)自維基百科。Andrej Karpathy 在 CS231n 中也提到有模型通過(guò)這個(gè)函數(shù)取得了不錯(cuò)的效果。
由于 TensorFlow 中沒(méi)有這個(gè)函數(shù)的實(shí)現(xiàn),在這里我們通過(guò)函數(shù)定義實(shí)現(xiàn)了 Leaky ReLU,其中 alpha 是一個(gè)很小的數(shù)。在輸出層我們使用 tanh 函數(shù),這是因?yàn)?tanh 在這里相比 sigmoid 的結(jié)果會(huì)更好一點(diǎn)(在這里要注意,由于生成器的生成圖片像素限制在了 (-1, 1) 的取值之間,而 MNIST 數(shù)據(jù)集的像素區(qū)間為 [0, 1],所以在訓(xùn)練時(shí)要對(duì) MNIST 的輸入做處理,具體見(jiàn)訓(xùn)練部分的代碼)。到此,我們構(gòu)建好了生成器,它通過(guò)接收一個(gè)噪聲圖片輸出一個(gè)與真實(shí)圖片一樣 size 的圖像。
定義判別器
判別器的結(jié)構(gòu)如下:
判別器接收一張圖片,并判斷它的真假,同樣隱層使用了 Leaky ReLU,輸出層為 1 個(gè)結(jié)點(diǎn),輸出為 1 的概率。代碼如下:
在這里,我們需要注意真實(shí)圖片與生成圖片是共享判別器的參數(shù)的,因此在這里我們留了 reuse 接口來(lái)方便我們后面調(diào)用。
定義參數(shù)
img_size 是我們真實(shí)圖片的 size=32*32=784。
smooth 是進(jìn)行 Label Smoothing Regularization 的參數(shù),在后面會(huì)介紹。
構(gòu)建網(wǎng)絡(luò)
接下來(lái)我們來(lái)構(gòu)建我們的網(wǎng)絡(luò),并獲得生成器與判別器返回的變量。
我們分別獲得了生成器與判別器的 logits 和 outputs。注意真實(shí)圖片與生成圖片是共享參數(shù)的,因此在判別器輸入生成圖片時(shí),需要 reuse 參數(shù)。
定義 Loss 和 Optimizer
有了上面的 logits,我們就可以定義我們的 loss 和 Optimizer。在這之前,我們?cè)賮?lái)回顧一下生成器和判別器各自的目的是什么:
對(duì)于給定的真實(shí)圖片,辨別器要為其打上標(biāo)簽 1;
對(duì)于給定的生成圖片,辨別器要為其打上標(biāo)簽 0;
對(duì)于生成器傳給辨別器的生成圖片,生成器希望辨別器打上標(biāo)簽 1。
我們來(lái)把上面這三句話轉(zhuǎn)換成代碼:
d_loss_real 對(duì)應(yīng)著真實(shí)圖片的 loss,它盡可能讓判別器的輸出接近于 1。在這里,我們使用了單邊的 Label Smoothing Regularization,它是一種防止過(guò)擬合的方式,在傳統(tǒng)的分類中,我們的目標(biāo)非 0 即 1,從直覺(jué)上來(lái)理解的話,這樣的目標(biāo)不夠 soft,會(huì)導(dǎo)致訓(xùn)練出的模型對(duì)于自己的預(yù)測(cè)結(jié)果過(guò)于自信。因此我們加入一個(gè)平滑值來(lái)讓判別器的泛化效果更好。
d_loss_fake 對(duì)應(yīng)著生成圖片的 loss,它盡可能地讓判別器輸出為 0。
d_loss_real 與 d_loss_fake 加起來(lái)就是整個(gè)判別器的損失。
而在生成器端,它希望讓判別器對(duì)自己生成的圖片盡可能輸出為 1,相當(dāng)于它在于判別器進(jìn)行對(duì)抗。
下面我們定義了優(yōu)化函數(shù),由于 GAN 中包含了生成器和判別器兩個(gè)網(wǎng)絡(luò),因此需要分開(kāi)進(jìn)行優(yōu)化,這也是我們?cè)谥岸x variable_scope 的原因。
訓(xùn)練模型
由于訓(xùn)練部分代碼太長(zhǎng),我在這里就不貼出來(lái)了,請(qǐng)前往我的 GitHub 下載代碼。在訓(xùn)練部分,我們記錄了部分圖像的生成過(guò)程,并記錄了訓(xùn)練數(shù)據(jù)的 loss 變化。
我們將整個(gè)訓(xùn)練過(guò)程的 loss 變化繪制出來(lái):
從圖中可以看出來(lái),最終的判別器總體 loss 在 1 左右波動(dòng),而 real loss 和 fake loss 幾乎在一條水平線上波動(dòng),這說(shuō)明判別器最終對(duì)于真假圖像已經(jīng)沒(méi)有判別能力,而是進(jìn)行隨機(jī)判斷。
查看過(guò)程結(jié)果
我們?cè)谡麄€(gè)訓(xùn)練過(guò)程中記錄了 25 個(gè)樣本在不同階段的 samples 圖像,以序列化的方式進(jìn)行了保存,我們的將 samples 加載進(jìn)來(lái)。samples 的 size=epochs x 2 x n_samples x 784,我們的迭代次數(shù)為 300 輪,25 個(gè)樣本,因此,samples 的 size=300 x 2 x 25 x 784。我們將最后一輪的生成結(jié)果打印出來(lái):
這就是我們的 GAN 通過(guò)學(xué)習(xí)真實(shí)圖片的分布后生成的圖像結(jié)果。
那么有同學(xué)可能會(huì)問(wèn)了,我們?nèi)绻胍催@ 300 輪中生成圖像的變化是什么樣該怎么辦呢?因?yàn)槲覀円呀?jīng)有了 samples,存儲(chǔ)了每一輪迭代的結(jié)果,我們可以挑選幾次迭代,把對(duì)應(yīng)的圖像打出來(lái):
這里我挑選了第 0, 5, 10, 20, 40, 60, 80, 100, 150, 250 輪的迭代效果圖,在這個(gè)圖中,我們可以看到最開(kāi)始的時(shí)候只有中間是白色,背景黑色塊中存在著很多噪聲。隨著迭代次數(shù)的不斷增加,生成器制造 “假圖” 的能力也越來(lái)越強(qiáng),它逐漸學(xué)得了真實(shí)圖片的分布,最明顯的一點(diǎn)就是圖片區(qū)分出了黑色背景和白色字符的界限。
生成新的圖片
如果我們想重新生成新的圖片呢?此時(shí)我們只需要將我們之前保存好的模型文件加載進(jìn)來(lái)就可以啦。
整篇文章基于 MNIST 數(shù)據(jù)集構(gòu)造了一個(gè)簡(jiǎn)單的 GAN 模型,相信小伙伴看完代碼會(huì)對(duì) GAN 有一個(gè)初步的了解。從最終的模型結(jié)果來(lái)看,生成的圖像能夠?qū)⒈尘芭c數(shù)字區(qū)分開(kāi),黑色塊噪聲逐漸消失,但從顯示結(jié)果來(lái)看還是有很多模糊區(qū)域的。
對(duì)于這里的圖片處理,相信很多小伙伴會(huì)想到卷積神經(jīng)網(wǎng)絡(luò),那么后面我們還會(huì)將生成器和判別器改為卷積神經(jīng)網(wǎng)絡(luò)來(lái)構(gòu)造深度卷積 GAN,它對(duì)于圖片的生成會(huì)取得更好的效果。
如果覺(jué)得不錯(cuò),請(qǐng)給 GitHub 點(diǎn)個(gè) Star 吧~
雷峰網(wǎng)版權(quán)文章,未經(jīng)授權(quán)禁止轉(zhuǎn)載。詳情見(jiàn)轉(zhuǎn)載須知。