2
本文作者: 小東 | 2017-01-12 11:43 |
雷鋒網(wǎng)按:如果對人工智能稍有了解的小伙伴們,或多或少都聽過反向傳播算法這個名詞,但實際上BP到底是什么?它有著怎樣的魅力與優(yōu)勢?本文發(fā)布于 offconvex.org,作者 Sanjeev Arora與 Tengyu Ma,雷鋒網(wǎng)對此進行了編譯,未經(jīng)許可不得轉(zhuǎn)載。
目前網(wǎng)絡上關于反向傳播算法的教程已經(jīng)很多,那我們還有必要再寫一份教程嗎?答案是‘需要’。
為什么這么說呢?我們教員Sanjeev最近要給本科生上一門人工智能的課,盡管網(wǎng)上有很多反向傳播算法的教程,但他卻找不到一份令他滿意的教程,因此我們決定自己寫一份關于反向傳播算法的教程,介紹一下反向傳播算法的歷史背景、原理、以及一些最新研究成果。
PS:本文默認讀者具備一定的基礎知識(如了解梯度下降、神經(jīng)網(wǎng)絡等概念)。
反向傳播算法是訓練神經(jīng)網(wǎng)絡的經(jīng)典算法。在20世紀70年代到80年代被多次重新定義。它的一些算法思想來自于60年代的控制理論。
在輸入數(shù)據(jù)固定的情況下、反向傳播算法利用神經(jīng)網(wǎng)絡的輸出敏感度來快速計算出神經(jīng)網(wǎng)絡中的各種超參數(shù)。尤其重要的是,它計算輸出f對所有的參數(shù)w的偏微分,即如下所示:?f/?wi,f代表神經(jīng)元的輸出,wi是函數(shù)f的第i個參數(shù)。參數(shù)wi代表網(wǎng)絡的中邊的權(quán)重或者神經(jīng)元的閾值,神經(jīng)元的激活函數(shù)具體細節(jié)并不重要,它可以是非線性函數(shù)Sigmoid或RELU。這樣就可以得到f相對于網(wǎng)絡參數(shù)的梯度?f ,有了這個梯度,我們就可以使用梯度下降法對網(wǎng)絡進行訓練,即每次沿著梯度的負方向(??f)移動一小步,不斷重復,直到網(wǎng)絡輸出誤差最小。
在神經(jīng)網(wǎng)絡訓練過程中,我們需要注意的是,反向傳播算法不僅需要準確計算梯度。還需要使用一些小技巧對我們的網(wǎng)絡進行訓練。理解反向傳播算法可以幫助我們理解那些在神經(jīng)網(wǎng)絡訓練過程中使用的小技巧。
反向傳播算法之所以重要,是因為它的效率高。假設對一個節(jié)點求偏導需要的時間為單位時間,運算時間呈線性關系,那么網(wǎng)絡的時間復雜度如下式所示:O(Network Size)=O(V+E),V為節(jié)點數(shù)、E為連接邊數(shù)。這里我們唯一需要用的計算方法就是鏈式法則,但應用鏈式法則會增加我們二次計算的時間,由于有成千上萬的參數(shù)需要二次計算,所以效率就不會很高。為了提高反向傳播算法的效率,我們通過高度并行的向量,利用GPU進行計算。
注:業(yè)內(nèi)人士可能已經(jīng)注意到在標準的神經(jīng)網(wǎng)絡訓練中,我們實際上關心的是訓練損失函數(shù)的梯度,這是一個與網(wǎng)絡輸出有關的簡單函數(shù),但是上文所講的更具有普遍意義,因為神經(jīng)網(wǎng)絡是可以增加新的輸出節(jié)點的,此時我們要求的就是新的網(wǎng)絡輸出與網(wǎng)絡超參數(shù)的偏微分。
反向傳播算法適用于有向非循環(huán)網(wǎng)絡,為了不失一般性,非循環(huán)神經(jīng)網(wǎng)絡可以看做是一個多層神經(jīng)網(wǎng)絡,第t+1層神經(jīng)元的輸入來自于第t層及其下層。我們使用f表示網(wǎng)絡輸出,在本文中我們認為神經(jīng)網(wǎng)絡是一個上下結(jié)構(gòu),底部為輸入,頂部為輸出。
規(guī)則1:為了先計算出參數(shù)梯度,先求出 ?f/?u ,即表示輸出f對節(jié)點u的偏微分。
我們使用規(guī)則1來簡化節(jié)點偏微分計算。下面我將具體說一下?f/?u的含義。我們做如下假設,先刪除節(jié)點u的所有輸入節(jié)點。然后保持網(wǎng)絡中的參數(shù)不變。現(xiàn)在我們改變u的值,此時與u相連的高層神經(jīng)元也會受到影響,在這些高層節(jié)點中,輸出f也會受到影響。那么此時?f/?u就表示當節(jié)點u變化時,節(jié)點f的變化率。
規(guī)則1就是鏈式法則的直接應用,如下圖所示,u是節(jié)點 z1,…,zm的加權(quán)求和,即u=w1*z1+?+wn*zn,然后通過鏈式法則對w1求偏導數(shù),具體如下:
由上式所示,只有先計算?f/?u,然后才能計算?f/?w1。
為了計算節(jié)點的偏微分,我們先回憶一下多元鏈式法則,多元鏈式法則常用來描述偏微分之間的關系。 即假設f是關于變量u1,…,un的函數(shù),而u1,…,un又都是關于變量z的函數(shù),那么f關于z的偏導數(shù)如下:
這是鏈式法則2的一般式,是鏈式法則的1的子式。這個鏈式法則很適合我們的反向傳播算法。下圖就是一個符合多元鏈式法則的神經(jīng)網(wǎng)絡示意圖。
如上圖所示,先計算f相對于u1,…,un的偏導數(shù),然后將這些偏導數(shù)按權(quán)重線性相加,得到f對z的偏導數(shù)。這個權(quán)重就是u1,…,un對z的偏導,即?uj/?z。此時問題來了,我么怎么衡量計算時間呢?為了與教課書中保持一致,我們做如下假設:u節(jié)點位于t+1層的,z節(jié)點位于t層或t層以下的子節(jié)點,此時我們記?u/?z的運算時間為單位時間。
我們首先要指出鏈式法則是包含二次計算的時間。許多作者都不屑于講這種算法,直接跳過的。這就好比我們在上算法排序課時,老師都是直接講快速排序的,像那些低效排序算法都是直接跳過不講的。
樸素算法就是計算節(jié)點對ui與uj之間偏導數(shù),在這里節(jié)點ui的層級要比uj高。在V*V個節(jié)點對的偏導值中包含?f/?ui的值,因為f本身就是一個節(jié)點,只不過這個節(jié)點比較特殊,它是一個輸出節(jié)點。
我們以前饋的形式進行計算。我們計算了位于t層及t層以下的所有節(jié)點對之間的偏導數(shù),那么位于t+1層的ul對uj的偏導數(shù)就等于將所有ui與uj的偏導數(shù)進行線性加權(quán)相加。固定節(jié)點j,其時間復雜度與邊的數(shù)量成正比,而j是有V個值,此時時間復雜度為O(VE)。
反向傳播算法如其名所示,就是反向計算偏微分,信息逆向傳播,即從神經(jīng)網(wǎng)絡的高層向底層反向傳播。
信息協(xié)議:節(jié)點u通過高層節(jié)點獲取信息,節(jié)點u獲取的信息之和記做S。u的低級節(jié)點z獲取的信息為S??u/?z
很明顯,每個節(jié)點的計算量與其連接的神經(jīng)元個數(shù)成正比,整個網(wǎng)絡的計算量等于所有節(jié)點運算時間之和,所有節(jié)點被計算兩次,故其時間復雜度為O(Network Size)。
我們做如下證明:S等于?f/?z。
證明如下:當z為輸出層時,此時?f/?z=?f/?f=1
假如對于t+1層及其高層假設成立,節(jié)點u位于t層,它的輸出邊與t+1層的u1,u2,…,um節(jié)點相連,此時節(jié)點從某個節(jié)點j收到的信息為(?f/?uj)×(?uj/?z),根據(jù)鏈式法則,節(jié)點z收到的總信息為S=
在上文中,關于神經(jīng)網(wǎng)絡、節(jié)點計算,我們并沒有細講。下面我們將具體講一下,我們將節(jié)點與節(jié)點之間的計算看做是一個無環(huán)圖模型,許多自動計算微分的工具包(如:autograd,tensorflow)均采用這一模型。這些工具就是通過這個無向圖模型來計算輸出與網(wǎng)絡參數(shù)的偏導數(shù)的。
我們首先注意到法則1就是對這個的一般性描述,這個之所以不失一般性是因為我們可以將邊的權(quán)值也看做節(jié)點(即葉節(jié)點)。這個很容易轉(zhuǎn)換,如下圖所示,左側(cè)是原始網(wǎng)絡,即一個單節(jié)點和其輸入節(jié)點、輸入節(jié)點的權(quán)重。右側(cè)是將邊的權(quán)重轉(zhuǎn)換為葉節(jié)點。網(wǎng)絡中的其它節(jié)點也做類似轉(zhuǎn)換。
只要局部偏導數(shù)計算的效率足夠高,那么我們就可以利用上文所說的信息協(xié)議來計算各個節(jié)點的偏微分。即對節(jié)點u來講,我們應該先找出它的的輸入節(jié)點有哪些,即z1,…,zn。然后計算在u的偏微分的基礎上計算zj的偏微分,由于輸出f對u的偏微分記做S,所以計算輸出f對zj的偏微分就是S??u?zj
這個算法可以按照如下規(guī)則分塊計算,首先明確節(jié)點u與輸入節(jié)點z1,…,zn 的關系,然后就是怎么計算偏導數(shù)的倍數(shù)(權(quán)重)S。即S??u/?zj。
擴展到向量空間:為了提高偏微分權(quán)重的計算效率,我們可以將節(jié)點的輸出也變?yōu)橐粋€向量(矩陣或張量)。此時我們將?u/?zj?S改寫為?u/?zj[S], 這個與我們的反向傳播算法思想是一致的,在反向傳播算法中,y是一個p維向量,x是一個q維向量,y是關于x的函數(shù),我們用?y/?x來表示由 ?yj/?xi所組成的q*p矩陣。聰明的讀者很快就會發(fā)現(xiàn),這就是我們數(shù)學中的雅克比矩陣。此外我們還可以證明S與u的維度相同、?u?zj[S] 與zj的維度也相同。
如下圖所示,W是一個d2*d3的矩陣,Z是一個d1*d2的矩陣,U=WZ故U是一個d1*d3維的矩陣,此時我們計算?U/?Z,最終得到一個d2d3×d1d3維的矩陣。但我們在反向傳播算法中,這個會算的很快,因為?U/?Z[S]=W?S,在計算機中我們可以使用GPU來進行類似向量計算。
在許多神經(jīng)網(wǎng)絡框架中,設計者想要是一些神經(jīng)元的參數(shù)能夠共享,這些參數(shù)包括邊的權(quán)重或者節(jié)點的閾值參數(shù)。例如,在卷積神經(jīng)網(wǎng)絡中,同一個卷集核使用的參數(shù)都是一樣的。簡而言之,就是a、b是兩個不同的參數(shù),但我們強制要求a與b的值相同,即參數(shù)共享。這就好比我們給神經(jīng)網(wǎng)絡新增一個節(jié)點u,并且節(jié)點u與a和b相連,并且a=u,b=u.,此時根據(jù)鏈式法則,?f/?u=(?f/?a)?(?a/?u)+(?f/?b)?(?b/?u)=?f/?a+?f/?b. 因此,對一個共享參數(shù)而言,其梯度就是輸出與參數(shù)節(jié)點之間的中間節(jié)點的偏導數(shù)之和。
上面我們講的是非循環(huán)神經(jīng)網(wǎng)絡,許多前沿應用(機器翻譯、語言理解)往往使用有向循環(huán)神經(jīng)網(wǎng)絡。在這種結(jié)構(gòu)的神經(jīng)網(wǎng)絡中會存在記憶單元或注意力機制,在這些單元或機制中往往存在復雜的求導計算。一開始我們使用梯度下降法訓練網(wǎng)絡,即在時間序列上對神經(jīng)網(wǎng)絡使用反向傳播算法,即對這個有向環(huán)狀結(jié)構(gòu)進行無限循環(huán),每一次循環(huán)的網(wǎng)絡結(jié)構(gòu)、網(wǎng)絡參數(shù)都是一樣的,但是網(wǎng)絡的輸入與輸出是不一樣的。在實際應用中我們會遇到梯度爆炸或梯度消失等問題,這些都會對結(jié)果收斂產(chǎn)生影響。為了解決這些問題,我們使用梯度剪切或者長短記憶模型(LSTM)等技術解決上述問題。
環(huán)狀神經(jīng)網(wǎng)絡可以高效計算梯度的事實促進了有記憶網(wǎng)絡甚至數(shù)據(jù)結(jié)構(gòu)的發(fā)展。使用梯度下降法,我們可可以對環(huán)狀結(jié)構(gòu)神經(jīng)網(wǎng)絡進行優(yōu)化,尋找最佳參數(shù),使得這個網(wǎng)絡可以解決特定計算問題。梯度下降法的極限目前仍在探索中。
在近似線性時間中,我們不僅可以使用梯度下降法,或許我們也可以使用2階導數(shù)對目標函數(shù)進行優(yōu)化。在優(yōu)化過程中,最關鍵的一步是計算海森矩陣與一個向量的積,下面我將向大家介紹如何在規(guī)模是O(Network size)的神經(jīng)網(wǎng)絡應用上述思想,這個例子與前面所講稍有不同,我們的初始神經(jīng)網(wǎng)絡應該是一個用反向傳播算法進行簡單優(yōu)化過的神經(jīng)網(wǎng)絡。
法則:假設在無環(huán)神經(jīng)網(wǎng)絡中,有V個節(jié)點,E條邊,網(wǎng)絡輸出為f,葉節(jié)點為z1,…,zm,那么必存在一個大小為O(V+E)的網(wǎng)絡,這個網(wǎng)絡的的輸入節(jié)點為z1,…,zm,輸出節(jié)點為?f/?z1,…,?f/?zm。
上面的定理可以通過在無環(huán)神經(jīng)網(wǎng)絡中實現(xiàn)消息直接傳遞來證明,緊接著我們將解釋一下如何計算?2f(z)?v。設g(z)=??f(z),v? ,有定理可知, g(z)可以由大小是O(V+E)神經(jīng)網(wǎng)絡計算得到,同理我們再次應用法則,在這個大小是O(V+E)的網(wǎng)絡計算g(z)的梯度,此時?g(z)=?2f(z)?v,此時我們就算出了海森矩陣與向量積的積,此時耗費的時間復雜度就是網(wǎng)絡規(guī)模的大小。
以上便是BP學習過程中需要了解的一些內(nèi)容,雷鋒網(wǎng)希望能讓你在學習過程中得到一個比較清晰的思路。當然,也歡迎你關注雷鋒網(wǎng)旗下公眾號“AI科技評論”與我們交流哦。
via Back-propagation, an introduction
雷峰網(wǎng)原創(chuàng)文章,未經(jīng)授權(quán)禁止轉(zhuǎn)載。詳情見轉(zhuǎn)載須知。