1
雷鋒網(wǎng) AI 科技評論按:OpenAI 今天發(fā)表了一篇博客介紹了自己新設(shè)計的元學(xué)習(xí)算法「Reptile」。算法的結(jié)構(gòu)簡單,但卻可以同時兼顧單個樣本和大規(guī)模樣本的精確學(xué)習(xí)。OpenAI 甚至還在博客頁面上做了一個互動界面,可以直接在四個方框里畫出訓(xùn)練樣本和要分類的樣本,算法可以立即學(xué)習(xí)、實時更新分類結(jié)果。
用 Reptile 實時小樣本學(xué)習(xí),分類手繪圖案。訓(xùn)練數(shù)據(jù)和要分類的圖案都可以任意繪制。歡迎到博客頁面 https://blog.openai.com/reptile/ 自行嘗試一下。
根據(jù) OpenAI 的介紹,這個新的元學(xué)習(xí)(meta-learning)算法 Reptile 的運作原理是反復(fù)對任務(wù)采樣、在其上運用梯度下降,并從初始參數(shù)開始持續(xù)地向著任務(wù)上學(xué)到的參數(shù)更新。Reptile 可以和應(yīng)用廣泛的元學(xué)習(xí)算法 MAML (model-agnostic meta-learning)達到同樣的表現(xiàn),同時還更易于實現(xiàn)、計算效率更高。雷鋒網(wǎng) AI 科技評論把這篇介紹博客全文翻譯如下。
元學(xué)習(xí)是一個學(xué)習(xí)「如何學(xué)習(xí)」的過程。一個元學(xué)習(xí)算法要面對一組任務(wù),其中每一個任務(wù)都是一個學(xué)習(xí)問題;然后算法會產(chǎn)生一個快速學(xué)習(xí)器,這個學(xué)習(xí)器有能力從很小數(shù)目的一組樣本中泛化。小樣本分類(few-shot classification)就是一個得到了充分研究的元學(xué)習(xí)問題,其中的每個任務(wù)都是一個分類問題,這里的學(xué)習(xí)器只能看到每個類別的 1 個到 5 個輸入-輸出樣本,然后它就要開始對新的輸入樣本進行分類。
和 MAML 類似,Reptile 首先會為神經(jīng)網(wǎng)絡(luò)尋找一組初始參數(shù),以便網(wǎng)絡(luò)稍后可以根據(jù)來自新任務(wù)的數(shù)量不多的幾個樣本進行精細調(diào)節(jié)(fine-tune)。不過,相比于 MAML 需要在梯度下降算法的計算圖中展開并求導(dǎo),Reptile 只需要簡單地在每個任務(wù)中以標(biāo)準(zhǔn)方法執(zhí)行隨機梯度下降(SGD),并不需要展開一個計算圖以及計算任何二階導(dǎo)數(shù)。這樣的設(shè)計讓 Reptile 所需的計算資源和存儲資源都比 MAML 更小。Reptile 的偽碼如下所示:
這里的最后一步也有另一種做法,可以把 Φ?W 整體作為梯度,然后把它嵌入進 Adam 之類的更復(fù)雜的優(yōu)化器中。
OpenAI 的研究人員們從一開始就感到驚訝,驚訝的是這個算法居然能運行出結(jié)果。當(dāng) k =1 的時候,這個算法就相當(dāng)于是「聯(lián)合訓(xùn)練」,在所有任務(wù)的混合體中做隨機梯度下降。雖然聯(lián)合訓(xùn)練在某些狀況下可以作為一種有用的初始化手段,但是零樣本學(xué)習(xí)(zero-shot learning)不可用的時候(比如當(dāng)輸出標(biāo)簽被隨機替換了),它所能學(xué)到的東西就非常有限。Reptile 算法中需要 k >1,也就是說,參數(shù)更新依靠的是損失函數(shù)的更高階導(dǎo)數(shù)。正如論文中所示的,此時算法的表現(xiàn)和 k =1 時相比有很大不同。
為了分析為什么 Reptile 會奏效,OpenAI 的研究人員們用泰勒級數(shù)逼近了參數(shù)更新。他們發(fā)現(xiàn) Reptile 的更新可以讓在同一個任務(wù)中不同 minibatch 的梯度的內(nèi)積最大化,這就對應(yīng)了模型的更強的泛化能力。這一發(fā)現(xiàn)也有超出了元學(xué)習(xí)研究領(lǐng)域的指導(dǎo)意義,可能可以用來解釋隨機梯度下降的泛化性質(zhì)。OpenAI 的研究表明 Reptile 和 MAML 執(zhí)行的參數(shù)更新非常詳細,包括其中有兩個相同的項,不過權(quán)重不一樣。
在 OpenAI 的實驗中,Reptile 和 MAML 在 Omniglot 和 Mini-ImageNet 的兩項小樣本學(xué)習(xí) benchmark 中取得了近似的表現(xiàn)。Reptile 收斂到最終解決方案的速度也更快,因為它的更新的方差更小。
OpenAI 對 Reptile 的分析也表明,通過對隨機梯度下降的梯度做不同的組合,我們可以得到許多中不同的算法。假設(shè)每個任務(wù)中使用不同的 minibatch 進行 k 步隨機梯度下降,得到的梯度分別為 g1、g2、……、gk。下圖就展示了在 Omniglot benchmark 中把每種不同的梯度和作為元學(xué)習(xí)的梯度的算法的學(xué)習(xí)曲線。g2 對應(yīng)的是一階 MAML,也就是最初的 MAML 論文中表述的算法。包括的梯度越多,算法學(xué)習(xí)得就越快,因為其中的方差會隨之減小。可以注意到僅僅使用 g1(也就是 k =1 時)并不會給這個任務(wù)帶來什么改進,因為零樣本學(xué)習(xí)無法被改進。
OpenAI 已經(jīng)把 Reptile 的算法實現(xiàn)開源在了 GitHub 上。它的計算借助 TensorFlow 完成,而且開源中也包括了復(fù)現(xiàn) Omniglot 和 Mini-ImageNet 的兩項小樣本學(xué)習(xí) benchmark 的代碼。他們也編寫了一個 JavaScript 的實現(xiàn),模型預(yù)訓(xùn)練仍然由 TensorFlow 完成,然后 JavaScript 根據(jù)樣本做精細調(diào)節(jié)。OpenAI 博客中的算法樣例(也就是本文的開頭動圖)就是借助 JavaScript 實現(xiàn)完成的。PyTorch 實現(xiàn)的完整代碼也一并在博客頁面上給出。
論文地址:https://d4mucfpksywv.cloudfront.net/research-covers/reptile/reptile_update.pdf
開源地址:https://github.com/openai/supervised-reptile
via OpenAI Blog,雷鋒網(wǎng) AI 科技評論編譯
雷峰網(wǎng)版權(quán)文章,未經(jīng)授權(quán)禁止轉(zhuǎn)載。詳情見轉(zhuǎn)載須知。