在前面幾節(jié)課的代碼演示中,我們都是先通過(guò)模型的正向傳播(forward propagation)對(duì)輸入計(jì)算模型輸出,利用損失函數(shù)得出預(yù)測(cè)值和真實(shí)值的損失值,然后使用反向傳播算法(back-propagation)快速計(jì)算目標(biāo)函數(shù)關(guān)于各個(gè)參數(shù)的梯度,最后使用隨機(jī)梯度下降算法(Stochastic Gradient Descent,SGD)基于前面得到的梯度值計(jì)算loss函數(shù)局部最小值,從而求解權(quán)重并更新網(wǎng)絡(luò)參數(shù)。
模型訓(xùn)練關(guān)鍵步驟
基于反向傳播算法(back-propagation)的 動(dòng)求梯度極 簡(jiǎn)化了深度學(xué)習(xí)模型訓(xùn)練算法的實(shí)現(xiàn)。
這 節(jié)我們將使 數(shù)學(xué)和計(jì)算圖(computational graph)兩個(gè) 式來(lái)描述正向傳播和反向傳播。具體地,我們將以帶 L2 范數(shù)正則化的單隱藏層感知機(jī)為樣例模型解釋正向傳播和反向傳播。
1.正向傳播
正向傳播是指對(duì)神經(jīng) 絡(luò)沿著從輸 層到輸出層的順序,依次計(jì)算并存儲(chǔ)模型的中間變量(包括輸出)。為簡(jiǎn)單起 ,假設(shè)輸 是 個(gè)特征為 x Rd 的樣本,且不考慮偏差項(xiàng),那么中間變量
網(wǎng)絡(luò)計(jì)算的中間變量
其中W(1) Rh d 是隱藏層的權(quán)重參數(shù)。把中間變量z Rh 輸 按元素操作的激活函數(shù)? 后,
我們將得到向量 度為h 的隱藏層變量
h = ?(z).
隱藏變量h 也是 個(gè)中間變量。假設(shè)輸出層參數(shù)只有權(quán)重W(2) Rq h,我們可以得到向量 度
為q 的輸出層變量
o = W(2)h.
假設(shè)損失函數(shù)為?,且樣本標(biāo)簽為y,我們可以計(jì)算出單個(gè)數(shù)據(jù)樣本的損失項(xiàng)
L = ?(o, y)
根據(jù)L2 范數(shù)正則化的定義,給定超參數(shù)λ,正則化項(xiàng)即
正則化懲罰項(xiàng)
其中矩陣的Frobenius 范數(shù)等價(jià)于將矩陣變平為向量后計(jì)算L2 范數(shù)。最終,模型在給定的數(shù)據(jù)
樣本上帶正則化的損失為
J = L + s.
我們將J 稱為有關(guān)給定數(shù)據(jù)樣本的 標(biāo)函數(shù),并在以下的討論中簡(jiǎn)稱 標(biāo)函數(shù)。
2.正向傳播的計(jì)算圖
通常繪制計(jì)算圖來(lái)可視化運(yùn)算符和變量在計(jì)算中的依賴關(guān)系,一般來(lái)說(shuō),計(jì)算圖中左下角是輸入,右上角是輸出。其中方框代表變量,圓圈代表運(yùn)算符,箭頭表示從輸入到輸出之間的依賴關(guān)系。
正向傳播的計(jì)算圖
3.反向傳播
反向傳播指的是計(jì)算神經(jīng)網(wǎng)絡(luò)參數(shù)梯度的方法。總的來(lái)說(shuō),反向傳播依據(jù)微積分中的鏈?zhǔn)?/a>法則,沿著從輸出層到輸入層的順序,依次計(jì)算并存儲(chǔ)目標(biāo)函數(shù)有關(guān)神經(jīng)網(wǎng)絡(luò)各層的中間變量以及參數(shù)的梯度。對(duì)輸入或輸出X , Y , Z 為任意形狀張量的函數(shù)Y = f ( X ) 和Z = g ( Y ) ,通過(guò)鏈?zhǔn)椒▌t,有:
鏈?zhǔn)椒▌t求導(dǎo)
其中prod運(yùn)算將根據(jù)兩個(gè)輸入的形狀,在必要的操作(如轉(zhuǎn)置和互換輸入位置)后對(duì)兩個(gè)輸入做乘法。
例中的模型,它的參數(shù)是W(1) 和W(2) ,因此反向傳播的目標(biāo)是計(jì)算目標(biāo)函數(shù)對(duì)參數(shù)的導(dǎo)數(shù) J/ W(1)和 J/ W(2)。
應(yīng)用鏈?zhǔn)椒▌t則依次計(jì)算各中間變量和參數(shù)的梯度,其計(jì)算次序與前向傳播中相應(yīng)中間變量的計(jì)算次序恰恰相反。
首先,分別計(jì)算目標(biāo)函數(shù)J = L + s 有關(guān)損失項(xiàng)L和正則項(xiàng)s 的梯度:
其次,依據(jù)鏈?zhǔn)椒▌t計(jì)算目標(biāo)函數(shù)有關(guān)輸出層變量的梯度 J/ o Rq:
接下來(lái),計(jì)算正則項(xiàng)有關(guān)兩個(gè)參數(shù)的梯度:
現(xiàn)在,我們可計(jì)算最靠近輸出層的模型參數(shù)的梯度 J/ W(2) Rq h。依據(jù)鏈?zhǔn)椒▌t,得到:
沿著輸出層向隱藏層繼續(xù)反向傳播,隱藏層變量的梯度 J/ h Rh 計(jì)算如下:
由于激活函數(shù)?是按元素運(yùn)算的,中間變量z 的梯度 J/ z Rh的計(jì)算需要使用按元素乘法符 :
最終,可以得到最靠近輸入層的模型參數(shù)的梯度 J/ W(1) Rh d。依據(jù)鏈?zhǔn)椒▌t,得到:
4.正向傳播和反向傳播的訓(xùn)練關(guān)系
在訓(xùn)練深度學(xué)習(xí)模型時(shí),正向傳播和反向傳播之間相互依賴。
一方面,正向傳播的計(jì)算可能依賴于模型參數(shù)的當(dāng)前值。而這些模型參數(shù)是在反向傳播的梯度計(jì)算后通過(guò)優(yōu)化算法迭代的。
例如,計(jì)算正則化項(xiàng)
依賴模型參數(shù)W(1) 和 W(2) 的當(dāng)前值。而這些當(dāng)前值是優(yōu)化算法最近 次根據(jù)反向傳播算出梯度后迭代得到的。
另一方面,反向傳播的梯度計(jì)算可能依賴于各變量的當(dāng)前值。而這些變量的當(dāng)前值是通過(guò)正向傳播計(jì)算得到的。舉例來(lái)說(shuō),參數(shù)梯度 J/ W(2) = ( J/ o)hT + λW(2) 的計(jì)算需要依賴隱藏層變量的當(dāng)前值 h。這個(gè)當(dāng)前值是通過(guò)從輸 層到輸出層的正向傳播計(jì)算并存儲(chǔ)得到的。
因此,在模型參數(shù)初始化完成后,我們交替地進(jìn) 正向傳播和反向傳播,并根據(jù)反向傳播計(jì)算的梯度迭代模型參數(shù)。既然我們?cè)诜聪騻鞑ブ惺?了正向傳播中計(jì)算得到的中間變量來(lái)避免重復(fù)計(jì)算,那么這個(gè)重 也導(dǎo)致正向傳播結(jié)束后不能 即釋放中間變量?jī)?nèi)存。這也是訓(xùn)練要 預(yù)測(cè)占 更多內(nèi)存的 個(gè)重要原因。另外需要指出的是,這些中間變量的個(gè)數(shù)跟 絡(luò)層數(shù)線性相關(guān),每個(gè)變量的 小跟批量 小和輸 個(gè)數(shù)也是線性相關(guān)的,它們是導(dǎo)致較深的神經(jīng) 絡(luò)使 較 批量訓(xùn)練時(shí)更容易超內(nèi)存的主要原因。
總結(jié)
正向傳播沿著從輸 層到輸出層的順序,依次計(jì)算并存儲(chǔ)神經(jīng) 絡(luò)的中間變量。
反向傳播沿著從輸出層到輸 層的順序,依次計(jì)算并存儲(chǔ)神經(jīng) 絡(luò)中間變量和參數(shù)的梯度。
所謂反向傳播,傳播的是損失,也就是根據(jù)最后的損失,計(jì)算網(wǎng)絡(luò)中每一個(gè)節(jié)點(diǎn)的梯度,這里利用了鏈?zhǔn)椒▌t,使得梯度的計(jì)算并不是很復(fù)雜。
在訓(xùn)練深度學(xué)習(xí)模型時(shí),正向傳播和反向傳播相互依賴。