機械学習の基礎【誤差逆伝播法の仕組みを計算グラフで理解する】

この記事を読むのに必要な時間は約 12 分です。

\(\require{cancel}\)

今回は誤差逆伝播法の仕組みについて計算グラフを使いながら解説します。

 

この記事を書いている僕はシステムエンジニア6年目

 

普段はJavaでWebアプリを作ったりSQL書いたり・・・、

なので最近流行りのPython、数学、人工知能、デープラーニングができる人には正直憧れています。。。。

 

 

自分も一から勉強してこの辺りできるようになりたい、、画像認識モデルを作ったりして、アプリに組み込みたい!

これが機械学習、深層学習の勉強を始めたきっかけでした。

 

 

体系的に、この分野の基礎から学ぼうとJDLAのG検定の勉強をして合格するところまでいきました。

次のステップとして、

実際にPythonでコードを書きながら機械学習や深層学習の知識を深めているところです。。。

 

 

前回はニューラルネットワークの勾配降下法を利用した学習について確認しました。

機械学習の基礎まとめ【勾配降下法を利用したニューラルネットワークの学習】

 

勾配降下法では重みパラメータに関する損失関数の勾配を数値微分によって求めました。

数値微分はシンプルで、実装が簡単でしたが、計算に時間がかかります。。。

 

そこで、今回は重みパラメータの勾配の計算を効率良く行う手法である誤差逆伝播法の仕組みについて確認します。

また、誤差逆伝播法を正しく理解するために計算グラフをを用います。

 

 

計算グラフ

 

 

誤差逆伝播法の前に、計算グラフについて確認していきます。

まずは、以下の問題を計算グラフで解いてみます。

 

問.りんご2個を買いました。りんごは1個100円、消費税が10%かかるものとして、支払い金額を求めなさい。

 

計算グラフは以下のようになります。

 

 

計算グラフを見ると、

簡単ですが、りんごの100円と個数の2個が掛け合わされ200円に、

さらに消費税の1.1(10%)が掛け合わされ最終的に220円になっていることが分かります。

 

このように、計算グラフ上で計算を左から右へ進めるのを順方向への伝播、略して順伝播(forward propagation)と言います。

逆に、右から左へ計算を進めるのを逆伝播と言い、この先で説明します。

 

 

計算グラフの特徴

計算グラフの特徴は「局所的な計算」を伝播することによって最終的な結果を得ることができる点です。

計算グラフの特徴は計算グラフ全体でどのようなことが行われようと、自分に関係する情報だけから次の結果を出力することができるのです。

 

例えば、りんご2個とそれ以外のいろんな買い物をした場合、

計算グラフを以下のように表すことができます。

 

 

上の図を見るとりんごの計算はそれ以外のいろんな買い物の詳細を知らなくても計算できます。

複雑な計算の結果である4,000円をりんごの計算結果に足し合わせるだけでいいのです。

 

このように、計算グラフでは例え全体の計算がどんなに複雑であったとしても、各ステップでやることは、

対象とするノードの「局所的な計算」です。その結果を伝達することで全体を構成する複雑な計算結果が得られます。

 

 

計算グラフを使う意味

計算グラフを使う意味は、

計算グラフでは計算過程の値(りんごを2個買った時の200円など)を保持でき、

逆伝播によって「微分」を効率良く計算できる点にあります。

 

それでは、計算グラフの逆伝播について確認していきます。

例えば、先ほどのりんごの計算の例で、

「りんごの値段が値上がりした場合、最終的な支払い金額にどのように影響するのか」を知りたいとします。

 

これは、「りんごの値段に関する支払い金額の微分」を求めることに相当します。

記号で表すと、

りんごの値段を\(x\)、支払い金額を\(L\)とした場合\(\frac{ \partial L }{ \partial x }\)を求めます。

 

この微分の値は、

りんごの値段が「少しだけ」値上がりした場合に、支払い金額がどれだけ増加するかを表したものになります。

 

りんごの値段に関する支払い金額の微分のような値は、

以下のように計算グラフで逆方向の伝播を行えば求めることができます。(一先ず結果のみ)

 

 

上の例だと逆伝播は右から左へ「1→1.1→2.2」と微分の値が伝達されていき、

この結果からりんごの値段に関する支払い金額の微分の値は2.2ということがわかります。

 

これはりんごが1円値上がりしたら、最終的な支払い金額が2.2円増えることを意味します。

(正確には、りんごの値段がある微小な値だけ増加したら、最終的な金額はその微小な値の2.2倍だけ増加することを意味する)

 

 

ここでは、りんごの値段に関する微分を求めましたが、

「消費前に関する支払い金額の微分」「りんごの個数に関する支払い金額の微分」

なども同様の手順で求めることができ、

 

その際には、途中まで求めた微分(図の例だと1や1.1など)の結果を共有することができ、

効率良く複数の微分を計算することができます。

 

このように、計算グラフの利点は、順伝播と逆伝播によって、

各変数の微分の値を効率良く求めることができる点にあります。

 

 

計算グラフの逆伝播の原理

先ほどのりんごの例で逆伝播について触れましたが、

ここでは逆伝播の原理について確認します。

 

まず、\(z = (x + y)^{2}\)という式について考えます。

この式は以下の2つの式で構成されます。

$$z = t^{2}$$

$$t = x + y$$

これは合成関数と呼ばれます。

 

この式について、\(\frac{ \partial z }{ \partial x }\)(\(x\)に関する\(z\)の微分)を求めます。

これは、\(\frac{ \partial z }{ \partial t }\)(\(t\)に関する\(z\)の微分)と\(\frac{ \partial t }{ \partial x }\)(\(x\)に関する\(t\)の微分)の積で表せます。

これは連鎖律と呼ばれます。

数式にすると、

$$\frac{ \partial z }{ \partial x }=\frac{ \partial z }{ \partial t }\frac{ \partial t}{ \partial x }$$

 

ちょうど\(\partial t\)が打ち消しあいます。

$$=\frac{ \partial z }{ \cancel{\partial t} }\frac{ \cancel{\partial t} }{ \partial x }$$

 

連鎖律を使って、

\(z = (x + y)^{2}\)の\(\frac{ \partial z }{ \partial x }\)(\(x\)に関する\(z\)の微分)を求めると、

$$\frac{ \partial z }{ \partial t } = 2t$$

$$\frac{ \partial t }{ \partial x } = 1$$

$$\frac{ \partial z}{ \partial x } = \frac{ \partial z}{ \partial t }\frac{ \partial t}{ \partial x } = 2t \times 1 = 2(x + y)$$

となります。

 

では次に、

連鎖律を使って行った計算を、計算グラフで表すと、

 

 

図を見ると、

計算グラフの逆伝播は右から左へ、

ノードの入力信号に対して、ノードの局所的な微分(偏微分)を乗算して次のノードへ伝播していることが分かります。

 

一番左の結果を見ると、\(\frac{ \partial z}{ \partial z }\frac{ \partial z}{ \partial t }\frac{ \partial t}{ \partial x } = 1 \times \frac{ \partial z}{ \partial t }\frac{ \partial t}{ \partial x } = \frac{ \partial z}{ \partial x}\)が成り立ち、「\(x\)に関する\(z\)の微分」であることが分かります。

 

このように逆伝播が行っていることは、連鎖律の原理から構成されています。

 

 

まとめ

 

 

今回は誤差逆伝播法の仕組みについて計算グラフを利用しながら確認しました。

誤差逆伝播法の雰囲気は掴めたかと思います。

 

次回は、グラフのノードに加算、乗算、内積、シグモイド、ソフトマックスなどを組み込んで

複雑なニューラルネットワークを作って、その順伝播、逆伝播の確認をしたいと思います。

 

 

参考にした資料

 

 

コメントを残す

メールアドレスが公開されることはありません。 * が付いている欄は必須項目です

このサイトはスパムを低減するために Akismet を使っています。コメントデータの処理方法の詳細はこちらをご覧ください