tex練習

詳解ディープラーニングという本を読みました。これの内容をふんわりメモしていきます。 {\displaystyle}

単純パーセプトロン

\(x_n\) : 入力
\(w_n\) : ネットワークの重み
\(\theta\) : ニューロンが発火する閾値
\(y\) : 出力(発火する、しない) $$y = \begin{cases} 1 & (w_1x_1 +w_2x_2 + \cdots +w_nx_n \geq \theta) \\ 0 & (w_1x_1 + w_2x_2 + \cdots + w_nx_n \lt \theta) \end{cases} $$ 整理する
$$ f(x) = \begin{cases} 1 & (x \geq 0) \\ 0 & (x \lt 0) \end{cases} $$ $$ \mathbf{x} = \left(\begin{array}{cc} x_1 \\ x_2 \\ \vdots \\ x_n\end{array}\right), \mathbf{w} = \left(\begin{array}{cc} w_1 \\ w_2 \\ \vdots \\ w_n\end{array}\right) $$ $$ b = -\theta $$ とすると以下にまとまる (※\(f(x)\)はステップ関数というらしい) $$ y = f({}^t\mathbf{w} \mathbf{x} + b) $$

このモデルに誤り訂正学習法を用いて最適な重みベクトル\(\mathbf{w}\)、バイアス\(b\)を得る。

誤り訂正学習法

正しい出力を\( t \)とし、誤差\( (t-y) \)によりパラメータを修正する。
$$ \Delta \mathbf{w} = (t-y)\mathbf{x} \\ \Delta b = -(t-y) $$ $$\mathbf{w}^{(k+1)} = \mathbf{w}^{(k)} + \Delta \mathbf{w} \\ b^{(k+1)} = b^{(k)} + \Delta b $$

実装
import numpy as np
import matplotlib.pyplot as plt

rng = np.random.RandomState(123)

d = 2    # データの次元
N = 10   # 各パターンのデータ数
mean = 5 # ニューロンが発火するデータの平均値

x1 = rng.randn(N, d) + np.array([0, 0])
x2 = rng.randn(N, d) + np.array([mean, mean])
x = np.concatenate((x1, x2), axis=0)


w = np.zeros(d) # 重みベクトル
b = 0           # バイアス

def y(x):
    return step(np.dot(w, x) + b)
def step(x):
    return 1 * (x > 0)
def t(i):
    if i < N:
        return 0
    else:
        return 1

while True:
    classified = True
    for i in range(N * 2):
        delta_w = (t(i) - y(x[i])) * x[i]
        delta_b = (t(i) - y(x[i]))
        w += delta_w
        b += delta_b
        classified *= all(delta_w == 0) * (delta_b == 0)
    if classified:
        break

print('w: ', w)
print('b: ', b)
print('y(0,0) = ', y([0, 0]))
print('y(5,5) = ', y([5, 5]))

p = np.arange(-3, 8, 0.01)
q = (w[0]*p +b) / -w[1]

plt.plot(p, q)

xx, yy = np.hsplit(x, 2)
plt.plot(xx, yy, '.')
plt.ylim(-3, 10)
plt.show()
w:  [ 2.14037745  1.2763927 ]
b:  -9
y(0,0) =  0
y(5,5) =  1

f:id:pastalian:20171129214131p:plain

学習より直線\( 2.14037745x_1 + 1.2763927x_2 - 9 = 0 \)で平均が(0,0)の発火しないデータと平均が(5,5)の発火するデータを区別できるようになった。