0
本文作者: 三川 | 2017-02-19 20:55 |
編者按:上圖是 Yann LeCun 對(duì) GAN 的贊揚(yáng),意為“GAN 是機(jī)器學(xué)習(xí)過去 10 年發(fā)展中最有意思的想法。”
本文作者為前谷歌高級(jí)工程師、AI 初創(chuàng)公司 Wavefront 創(chuàng)始人兼 CTO Dev Nag,介紹了他是如何用不到五十行代碼,在 PyTorch 平臺(tái)上完成對(duì) GAN 的訓(xùn)練。雷鋒網(wǎng)編譯整理。
Dev Nag
在進(jìn)入技術(shù)層面之前,為照顧新入門的開發(fā)者,雷鋒網(wǎng)先來介紹下什么是 GAN。
2014 年,Ian Goodfellow 和他在蒙特利爾大學(xué)的同事發(fā)表了一篇震撼學(xué)界的論文。沒錯(cuò),我說的就是《Generative Adversarial Nets》,這標(biāo)志著生成對(duì)抗網(wǎng)絡(luò)(GAN)的誕生,而這是通過對(duì)計(jì)算圖和博弈論的創(chuàng)新性結(jié)合。他們的研究展示,給定充分的建模能力,兩個(gè)博弈模型能夠通過簡單的反向傳播(backpropagation)來協(xié)同訓(xùn)練。
這兩個(gè)模型的角色定位十分鮮明。給定真實(shí)數(shù)據(jù)集 R,G 是生成器(generator),它的任務(wù)是生成能以假亂真的假數(shù)據(jù);而 D 是判別器 (discriminator),它從真實(shí)數(shù)據(jù)集或者 G 那里獲取數(shù)據(jù), 然后做出判別真假的標(biāo)記。Ian Goodfellow 的比喻是,G 就像一個(gè)贗品作坊,想要讓做出來的東西盡可能接近真品,蒙混過關(guān)。而 D 就是文物鑒定專家,要能區(qū)分出真品和高仿(但在這個(gè)例子中,造假者 G 看不到原始數(shù)據(jù),而只有 D 的鑒定結(jié)果——前者是在盲干)。
理想情況下,D 和 G 都會(huì)隨著不斷訓(xùn)練,做得越來越好——直到 G 基本上成為了一個(gè)“贗品制造大師”,而 D 因無法正確區(qū)分兩種數(shù)據(jù)分布輸給 G。
實(shí)踐中,Ian Goodfellow 展示的這項(xiàng)技術(shù)在本質(zhì)上是:G 能夠?qū)υ紨?shù)據(jù)集進(jìn)行一種無監(jiān)督學(xué)習(xí),找到以更低維度的方式(lower-dimensional manner)來表示數(shù)據(jù)的某種方法。而無監(jiān)督學(xué)習(xí)之所以重要,就好像雷鋒網(wǎng)反復(fù)引用的 Yann LeCun 的那句話:“無監(jiān)督學(xué)習(xí)是蛋糕的糕體”。這句話中的蛋糕,指的是無數(shù)學(xué)者、開發(fā)者苦苦追尋的“真正的 AI”。
Dev Nag:在表面上,GAN 這門如此強(qiáng)大、復(fù)雜的技術(shù),看起來需要編寫天量的代碼來執(zhí)行,但事實(shí)未必如此。我們使用 PyTorch,能夠在 50 行代碼以內(nèi)創(chuàng)建出簡單的 GAN 模型。這之中,其實(shí)只有五個(gè)部分需要考慮:
R:原始、真實(shí)數(shù)據(jù)集
I:作為熵的一項(xiàng)來源,進(jìn)入生成器的隨機(jī)噪音
G:生成器,試圖模仿原始數(shù)據(jù)
D:判別器,試圖區(qū)別 G 的生成數(shù)據(jù)和 R
我們教 G 糊弄 D、教 D 當(dāng)心 G 的“訓(xùn)練”環(huán)。
1.) R:在我們的例子里,從最簡單的 R 著手——貝爾曲線(bell curve)。它把平均數(shù)(mean)和標(biāo)準(zhǔn)差(standard deviation)作為輸入,然后輸出能提供樣本數(shù)據(jù)正確圖形(從 Gaussian 用這些參數(shù)獲得 )的函數(shù)。在我們的代碼例子中,我們使用 4 的平均數(shù)和 1.25 的標(biāo)準(zhǔn)差。
2.) I:生成器的輸入是隨機(jī)的,為提高點(diǎn)難度,我們使用均勻分布(uniform distribution )而非標(biāo)準(zhǔn)分布。這意味著,我們的 Model G 不能簡單地改變輸入(放大/縮小、平移)來復(fù)制 R,而需要用非線性的方式來改造數(shù)據(jù)。
3.) G: 該生成器是個(gè)標(biāo)準(zhǔn)的前饋圖(feedforward graph)——兩層隱層,三個(gè)線性映射(linear maps)。我們使用了 ELU (exponential linear unit)。G 將從 I 獲得平均分布的數(shù)據(jù)樣本,然后找到某種方式來模仿 R 中標(biāo)準(zhǔn)分布的樣本。
4.) D: 判別器的代碼和 G 的生成器代碼很接近。一個(gè)有兩層隱層和三個(gè)線性映射的前饋圖。它會(huì)從 R 或 G 那里獲得樣本,然后輸出 0 或 1 的判別值,對(duì)應(yīng)反例和正例。這幾乎是神經(jīng)網(wǎng)絡(luò)的最弱版本了。
5.) 最后,訓(xùn)練環(huán)在兩個(gè)模式中變幻:第一步,用被準(zhǔn)確標(biāo)記的真實(shí)數(shù)據(jù) vs. 假數(shù)據(jù)訓(xùn)練 D;隨后,訓(xùn)練 G 來騙過 D,這里是用的不準(zhǔn)確標(biāo)記。道友們,這是正邪之間的較量。
即便你從沒接觸過 PyTorch,大概也能明白發(fā)生了什么。在第一部分(綠色),我們讓兩種類型的數(shù)據(jù)經(jīng)過 D,并對(duì) D 的猜測 vs. 真實(shí)標(biāo)記執(zhí)行不同的評(píng)判標(biāo)準(zhǔn)。這是 “forward” 那一步;隨后我們需要 “backward()” 來計(jì)算梯度,然后把這用來在 d_optimizer step() 中更新 D 的參數(shù)。這里,G 被使用但尚未被訓(xùn)練。
在最后的部分(紅色),我們對(duì) G 執(zhí)行同樣的操作——注意我們要讓 G 的輸出穿過 D (這其實(shí)是送給造假者一個(gè)鑒定專家來練手)。但在這一步,我們并不優(yōu)化、或者改變 D。我們不想讓鑒定者 D 學(xué)習(xí)到錯(cuò)誤的標(biāo)記。因此,我們只執(zhí)行 g_optimizer.step()。
這就完成了。據(jù)雷鋒網(wǎng)了解,還有一些其他的樣板代碼,但是對(duì)于 GAN 來說只需要這五個(gè)部分,沒有其他的了。
在 D 和 G 之間幾千輪交手之后,我們會(huì)得到什么?判別器 D 會(huì)快速改進(jìn),而 G 的進(jìn)展要緩慢許多。但當(dāng)模型達(dá)到一定性能之后,G 才有了個(gè)配得上的對(duì)手,并開始提升,巨幅提升。
兩萬輪訓(xùn)練之后,G 的輸入平均值超過 4,但會(huì)返回到相當(dāng)平穩(wěn)、合理的范圍(左圖)。同樣的,標(biāo)準(zhǔn)差一開始在錯(cuò)誤的方向降低,但隨后攀升至理想中的 1.25 區(qū)間(右圖),達(dá)到 R 的層次。
所以,基礎(chǔ)數(shù)據(jù)最終會(huì)與 R 吻合。那么,那些比 R 更高的時(shí)候呢?數(shù)據(jù)分布的形狀看起來合理嗎?畢竟,你一定可以得到有 4.0 的平均值和 1.25 標(biāo)準(zhǔn)差值的均勻分布,但那不會(huì)真的符合 R。我們一起來看看 G 生成的最終分布。
結(jié)果是不錯(cuò)的。左側(cè)的尾巴比右側(cè)長一些,但偏離程度和峰值與原始 Gaussian 十分相近。G 接近完美地再現(xiàn)了原始分布 R——D 落于下風(fēng),無法分辨真相和假相。而這就是我們想要得到的結(jié)果——使用不到 50 行代碼。
該說的都說完了,老司機(jī)請(qǐng)上 GitHub 把玩全套代碼。
地址:https://github.com/devnag/pytorch-generative-adversarial-networks
via medium
相關(guān)文章:
LS-GAN作者詮釋新型GAN:條條大路通羅馬,把GAN建立在Lipschitz密度上
GAN的理解與TensorFlow的實(shí)現(xiàn)
GAN學(xué)習(xí)指南:從原理入門到制作生成Demo,總共分幾步?
深度學(xué)習(xí)新星:GAN的基本原理、應(yīng)用和走向 | 雷鋒網(wǎng)公開課
雷峰網(wǎng)版權(quán)文章,未經(jīng)授權(quán)禁止轉(zhuǎn)載。詳情見轉(zhuǎn)載須知。