機械学習の基礎まとめ【損失関数(2乗和誤差、交差エントロピー誤差)】

今回は損失関数について、

 

「損失関数ってなに??」

「損失関数にはどんな種類のものがあるの??」

「Pythonで損失関数を実装すると??」

 

などをまとめてみました。

 

ちなみに、損失関数はJDLAのG検定にも出題される重要な内容です。

 

この記事を書いている僕はシステムエンジニア6年目

 

普段はJavaでWebアプリを作ったりSQL書いたり・・・、

なので最近流行りのPython、数学、人工知能、デープラーニングができる人には正直憧れています。。。。

 

 

自分も一から勉強してこの辺りできるようになりたい、、画像認識モデルを作ったりして、アプリに組み込みたい!

これが機械学習、深層学習の勉強を始めたきっかけでした。

 

 

体系的に、この分野の基礎から学ぼうとJDLAのG検定の勉強をして合格するところまでいきました。

次のステップとして、

実際にPythonでコードを書きながら機械学習や深層学習の知識を深めているところです。。。

 

 

今回は、JDLAのG検定にも出題される「損失関数」についてまとめてみました。

 

 

ニューラルネットワークの復習

 

 

損失関数について確認する前に、

まずはニューラルネットワークについて復習します。

 

ニューラルネットワークでは、

例えば、

以下の図のように手書き数字が何の数字かを推論(順方向への伝播)することができました。

 

 

詳細は以下にまとめてあります。

機械学習の基礎まとめ【ニューラルネットワークの推論(順方向への伝播)】

実は、ニューラルネットワークは推論だけでなく学習もできます。

 

学習とは、ニューラルネットワークの重みやバイアスを最適な値(精度が高くなるように)に自動で更新出来ることを指します。

 

 

上の例だと、手書き数字の2を入力されたら2が100%、それ以外を0%と推論出来る必要があります。

その推論結果に近づけるようにニューラルネットワークのバイアスと重みを学習によって最適な値に更新していきます。

その学習の際に必要なのが以降で説明する損失関数です。

 

 

損失関数

 

 

損失関数はニューラルネットワークの性能の「悪さ」を示す指標です。

この指標を基準にニューラルネットワークの学習では重みを更新していきます。

 

先ほどの例でいうと、、

推論結果が手書き数字2を100%、その他を0%と推論できれば性能の悪さが低くなり、

そのほかを1%、2%と推測してしまえば高くなります。

 

ややこしいかもしれませんが、

値が大きいほど性能が悪いことを表します。(値にマイナスをかければ逆になりますが、、)

 

損失関数には一般的に2乗和誤差や交差エントロピー誤差などを用います。

 

2乗和誤差

損失関数として用いられる有名なものに2乗和誤差(mean squared error)があります。

数式で表すと、

$$E = \displaystyle \frac{1}{2}\sum_{k} (y_{k} – t_{k})^2$$

 

上の式において、\(y_k\)はニューラルネットワークの出力、\(t_k\)は教師データを表します。

上の式は教師データ(正解)の値とニューラルネットワークの出力との差が大きいほどE(error)は大きくなることを表しています。

 

例えば、先ほどの例では、\(y_0\)がニューラルネットワークを「0」の確率として出力した値で、

\(t_2\)は「2」の正解の値(教師データ)です。

データとしてはこのような値が入ります。

Python

# softmax関数の出力値をイメージ
y=[0.1, 0.05, 0.6, 0.0, 0.05, 0.1, 0.0, 0.1, 0.0, 0.0]
# 教師データをイメージ(2が正解のone-hot表現)
t = [0, 0, 1, 0, 0, 0, 0, 0, 0, 0]

 

MEMO

one-hot表現とは教師データにおいて、正解ラベルのみ「1」、それ以外を「0」とするデータの表現です。

例えば、手書き数字(0~9)の2が正解ラベルの場合、

普通の「2」と表すのではなく、

「[0, 0, 1, 0, 0, 0, 0, 0, 0]」と配列で表します。

 

 

2乗和誤差をPythonで実装すると、

Python
import numpy as np
# 2乗和誤差
def mean_squared_error(y, t):
  return 0.5 * np.sum((y - t) ** 2)

# softmax関数の出力値をイメージ(2の確率が最も高い)
y1=[0.1, 0.05, 0.6, 0.0, 0.05, 0.1, 0.0, 0.1, 0.0, 0.0]
# 教師データをイメージ(2が正解のone-hot表現)
t = [0, 0, 1, 0, 0, 0, 0, 0, 0, 0]
print("例1(2と推測、答えは2):" + str(mean_squared_error(np.array(y1), np.array(t))))
# softmax関数の出力値をイメージ(7の確率が最も高い)
y2=[0.1, 0.05, 0.1, 0.0, 0.05, 0.1, 0.0, 0.6, 0.0, 0.0]
print("例2(7と推測、答えは2):" + str(mean_squared_error(np.array(y2), np.array(t))))

 

 

例1はニューラルネットワークが手書き数字「2」と推測し、教師データもone-hot表現で「2」が答えのケースをイメージしています。

一方例2はニューラルネットワークが手書き数字「7」と推測し、教師データはone-hot表現で「7」が答えのケースをイメージしています。

 

例2は推論結果と教師データが異なるため値が大きくなっていることがわかります。

 

2乗和誤差は、

ニューラルネットワークが出力した結果と教師データの値との誤差の総和が出力されます。

 

交差エントロピー誤差

損失関数として交差エントロピー誤差(cross entropy error)もよく用いられます。

数式で表すと、

$$E = \displaystyle -\sum_{k} (t_{k}logy_{k})$$

上の式において、\(y_k\)はニューラルネットワークの出力、\(t_k\)は教師データを表します。

また、\(log\)は底が自然対数の\(log_{e}\)を表します。

上の式は教師データ(正解)の値とニューラルネットワークの出力との差が大きいほどE(error)は大きくなることを表しています。

 

先ほどと同じく、教師データはone-hot表現であるとします。

そのため、正解ラベル以外は0を掛けるため、無視できます。

\(0 \times log0.1 + 0 \times log0.2 + 1 \times log0.6・・・\)

 

つまり、正解ラベルのニューラルネットワークの出力結果がどうかで値は決まります。

 

 

\(y = logx\)は\(x\)が1に近いほど0に近づき(誤差が小さい)、0に近いほど\(-\infty\)に近づき(誤差が大きい)ます。

 

 

交差エントロピー誤差をPythonで実装すると、

Python
import numpy as np
# 2乗和誤差
def cross_entropy_error(y, t):
  delta = 1e-7 # マイナス無限大対策
  return -np.sum(t * np.log(y + delta))

# softmax関数の出力値をイメージ(2の確率が最も高い)
y1=[0.1, 0.05, 0.6, 0.0, 0.05, 0.1, 0.0, 0.1, 0.0, 0.0]
# 教師データをイメージ(2が正解のone-hot表現)
t = [0, 0, 1, 0, 0, 0, 0, 0, 0, 0]
print("例1(2と推測、答えは2):" + str(cross_entropy_error(np.array(y1), np.array(t))))
# softmax関数の出力値をイメージ(7の確率が最も高い)
y2=[0.1, 0.05, 0.1, 0.0, 0.05, 0.1, 0.0, 0.6, 0.0, 0.0]
print("例2(7と推測、答えは2):" + str(cross_entropy_error(np.array(y2), np.array(t))))

 

 

例1、例2は先ほどと同じように推論が間違えていると値は大きくなっていることが分かります。

 

交差エントロピー誤差は、

教師データの正解の値に対するニューラルネットワーク出力の誤差の大きさのみを出力します。

 

まとめ

 

 

今回は損失関数について確認しました。

 

以下は抑えておきましょう。

  • ニューラルネットワークは学習の際、損失関数の値を基準として重みやバイアスの最適な値を探す
  • 損失関数には2乗和誤差と交差エントロピー誤差があり、ニューラルネットワークの出力と教師データの値の差が大きいほど大きな値を出力する

 

参考にした資料

 

 

コメントを残す

メールアドレスが公開されることはありません。 * が付いている欄は必須項目です

このサイトはスパムを低減するために Akismet を使っています。コメントデータの処理方法の詳細はこちらをご覧ください