見出し画像

やってみよう!機械学習 - 勾配降下法

何はともあれ以下のサイトを参考にColabで実行します。

少しコードを削除したも。以下を実行します。

import matplotlib.animation as animation
import matplotlib.pyplot as plt
import numpy as np


def objective(x):
    return x**2

def differential(func, x):
    h = 1e-4
    diff = (func(x + h) - func(x)) / h
    return diff

if __name__ == '__main__':
    x = np.arange(-10, 10, 0.1)
    y = objective(x)

    ans = 100
    lr = 0.1
    plots = []
    fig = plt.figure()
    plt.plot(x, y)
    for i in range(30):
        diff = differential(objective, ans)
        ans -= lr * diff
        line, = plt.plot(np.sqrt(ans), ans, color='m', marker='X')
        plots.append([line])

    plt.show()


見事に最小値までの軌跡が表示されています。

グラフ化している関数です。

def objective(x):
  return x**2

いわゆる

$${y = x^2}$$

のグラフです。機械学習の場合はこの式が損失関数になります。この損失関数の値を最小にするパラメータを算出するときに使います。

損失関数での値を最小にするために微分して傾きを求めます。

def differential(func, x):
    h = 1e-4
    diff = (func(x + h) - func(x)) / h
    return diff

そして

if __name__ == '__main__':

が実際に実行するコードです。

   x = np.arange(-10, 10, 0.1)
   y = objective(x)
   fig = plt.figure()
   plt.plot(x, y)

これで$${y = x^2}$$のグラフの描画をします。

以下のコードで最小値を計算してプロットしていきます。

  ans = 100
  lr = 0.1
  for i in range(30):

        diff = differential(objective, ans)
        ans -= lr * diff

        line, = plt.plot(np.sqrt(ans), ans, color='m', marker='X')
        plots.append([line])
    

diff = differential(objective, ans)
ans -= lr * diff

"ans"は初期値が"100"からスタート。これを"objective()”に代入して微分(傾き)を計算します。

次に"lr"は変化量、学習率などと呼ばれて、この変化量ごとに数値(ans)をずらしていきます。これを30回ループします。

for i in range(30):

グラフの描画については以下で行っています。

line, = plt.plot(np.sqrt(ans), ans, color='m', marker='X')
        plots.append([line])

これで"ans"のグラフの位置が"0"に向かって進んいて間隔も徐々に小さくなっているのがグラフを見るとわかります。

最初の1点目の数字を出してみます。

   diff = differential(objective, ans)
    ans -= lr * diff
    line, = plt.plot(np.sqrt(ans), ans, color='m', marker='X')
    plots.append([line])

    # for i in range(30):
    #     diff = differential(objective, ans)
    #     ans -= lr * diff
    #     line, = plt.plot(np.sqrt(ans), ans, color='m', marker='X')
    #     plots.append([line])
   
    plt.show()

    print( ans)
    print(np.sqrt(ans))

for  in部分をコメントアウトして実行し、 ansとp.sqrt(ansの値を出してみます。

79.9999899987597
8.944271350912812

ループでは"ans -= lr * diff"の連続で数値が減少しているのがわかります。

この記事が気に入ったらサポートをしてみませんか?