【Go・Deep Learning その③】 ゼロから作るdeep learning 3を復習 ーMNISTを学習する
これまで
今回
今回はざっくり全体像を説明しながら、MNIST学習していこうと思います。
MNISTをダウンロード⬇︎
とりあえずMNISTデータを取得しちゃいますかぁ
下の項目をまずダウンロードするコード書く📝
http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
const (
TRAIN_URL = "http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz"
LABEL_URL = "http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz"
TEST_TRAIN_URL = "http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz"
TEST_LABEL_URL = "http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz"
LOCAL_DATA_PATH = "data"
LOCAL_TRAIN_PATH = LOCAL_DATA_PATH + "/train-images-idx3-ubyte.gz"
LOCAL_LABELS_PATH = LOCAL_DATA_PATH + "/train-labels-idx1-ubyte.gz"
LOCAL_TEST_TRAIN_PATH = LOCAL_DATA_PATH + "/t10k-images-idx3-ubyte.gz"
LOCAL_TEST_LABELS_PATH = LOCAL_DATA_PATH + "/t10k-labels-idx1-ubyte.gz"
)
func downloadMnist() error {
urls := []string{TRAIN_URL, LABEL_URL, TEST_TRAIN_URL, TEST_LABEL_URL}
localPaths := []string{LOCAL_TRAIN_PATH, LOCAL_LABELS_PATH, LOCAL_TEST_TRAIN_PATH, LOCAL_TEST_LABELS_PATH}
if !util.Exists(LOCAL_TRAIN_PATH) ||
!util.Exists(LOCAL_LABELS_PATH) ||
!util.Exists(LOCAL_TEST_TRAIN_PATH) ||
!util.Exists(LOCAL_TEST_LABELS_PATH) {
if !util.Exists(LOCAL_DATA_PATH) {
if err := os.Mkdir(LOCAL_DATA_PATH, os.ModePerm); err != nil {
return err
}
}
for i := 0; i < len(urls); i++ {
log.Info("Download", urls[i], "...")
response, err := http.Get(urls[i])
if err != nil {
return err
}
body, err := ioutil.ReadAll(response.Body)
if err != nil {
return err
}
log.Info("Save", localPaths[i])
if err := ioutil.WriteFile(localPaths[i], body, os.ModePerm); err != nil {
return err
}
}
}
return nil
}
GoMNISTというライブラリ📚
mnistのデータを簡単に使いたい!😇
GoMNISTというのがGitHubにありました。
MNISTデータを学習データとラベルデータに読み込み・変換をおこなってくれます。
使うことをお勧めいたします。🙇♂️🙇♂️🙇♂️
ざっくりしたMainコード
lossAcc := dgraph.NewLossAcc("train", infragraph.New())
testLossAcc := dgraph.NewLossAcc("test", infragraph.New())
maxEpoch := 5 // 繰り返す回数
batchSize := 100 // バッチサイズ
hiddenSize := 1000 // 隠れ層
gray := func(m core.Matrix) core.Matrix { return m.CopyDivFloat(255) }
trainSet := datasets.NewMnist(dz.Train(true), dz.TransformData(gray)) // MNIST学習用
testSet := datasets.NewMnist(dz.Train(false), dz.TransformData(gray)) // MNISTテスト用
trainLoader := dz.NewDataLoader(trainSet, batchSize)
testloader := dz.NewDataLoader(testSet, batchSize, dz.DShuffle(false))
// 多重パーセプトロン 28*28 -> 1000 -> 1000 -> 10
model := models.NewMLP([]int{hiddenSize, hiddenSize, 10}, models.ActivationFunc(func(v ...dz.Variable) dz.Variables { return fn.NewRelu().Apply(v...) }))
// 最適化手法(Adam)
optim := optimizers.NewAdam().Setup(model)
for i := 0; i < maxEpoch; i++ {
sumLoss, sumAcc := 0., 0.
bar := pb.StartNew(trainLoader.Len())
for trainLoader.Next() {
bar.Increment()
x, t := trainLoader.Read() // データの読み込み
y := model.Apply(x).First() // モデルの実行
loss := fn.SoftmaxCrossEntropy(y, t) // 正解との誤差
acc := fn.Accuacy(y, t) // 正解率
model.ClearGrads()
loss.Backward(dz.RetainGrad(true)) // 微分微分微分
optim.Update() // 重みの更新
sumLoss += loss.Data().At(0, 0) * float64(t.Data().Len())
sumAcc += acc.Data().At(0, 0) * float64(t.Data().Len())
}
bar.Finish()
lossAcc.Add(sumLoss/float64(trainSet.Len()), sumAcc/float64(trainSet.Len()))
fmt.Println("epoch", i+1, " train: ",
"loss", sumLoss/float64(trainSet.Len()),
"accuracy", sumAcc/float64(trainSet.Len()))
sumLoss, sumAcc = 0, 0
// ここからテスト用
core.NoGrad(func() error {
for testloader.Next() {
x, t := testloader.Read()
y := model.Apply(x).First()
loss := fn.SoftmaxCrossEntropy(y, t)
acc := fn.Accuacy(y, t)
sumLoss += loss.Data().At(0, 0) * float64(t.Data().Len())
sumAcc += acc.Data().At(0, 0) * float64(t.Data().Len())
}
testLossAcc.Add(sumLoss/float64(testSet.Len()), sumAcc/float64(testSet.Len()))
fmt.Println("epoch", i+1, " test: ",
"loss", sumLoss/float64(testSet.Len()),
"accuracy", sumAcc/float64(testSet.Len()))
return nil
})
}
if err := lossAcc.Plot(); err != nil {
panic(err)
}
if err := testLossAcc.Plot(); err != nil {
panic(err)
}
よくわからない解説❓
サクッとコードを理解してみます。
モデルを用意
最初の重みは正規分布に従うランダムな値🔢
ちなみに重みの行列はインプットされる行列に形が依存する
今回のモデルのパーセプトロン: 28*28* -> 1000 -> 1000 -> 10
(row: N, col: 784) ✖️ (row: 784, col: 1000) + bias
(row: N, col: 1000) ✖️ (row: 1000, col: 1000) + bias
(row: N, col: 1000) ✖️(row: 1000, col: 10) + bias
(row: N, col: 10)
最適化手法を定義する!?
今回はAdamで!💪💪💪💪💪
簡単にいうとmomentumとSGDの良いとこどり的な?🤯
活性化関数にSigmoidではなくReluを使用(本に書いあったので🙇♂️)
バッチサイズに沿ってデータを切り取る
モデルの実行(「1.モデルを用意」のやつ)
結果からlossを求める(答えからのズレ)
lossから関数達を通して微分する
最適化手法(Adam)でlossの微分結果を元に重みを更新
大体こんなことを繰り返して正解率を高め、損失を小さくしていく
モデルの可視化
AccとLossとバグじゃね?🐛🐛🐛🐛🐛🐛🐛🐛🐛
Train
左: Acc: 0.9757166666666667
右: Loss: 0.347411994067936
Test
左: Acc: 0.9666
右: Loss: 0.5619361522344886
lossが大きい気もします。
Accも低い気がします。
おかしいですね。。。
これはコード見直して何がダメか調べないとダメですかね?頑張ります。
ソースコード
まとめ✅
今回はざっくりではありますが、MNISTの学習をやりました。🙇♂️
関数-> 損失->微分->最適化(更新)のイメージをしっかりと持つと迷子にならないと思いますね。。。。。。。。。。。。。。。。。
次回はCNNあたりになりそうです。
次こそいい感じの記事を、書きたいのに書きたいのにこうな
今回の復習にあたって困った事は、
Gonumが予想よりも遅かったことです。
(文句言うな俺。ごめんなさい。🙇♂️)
そこでOpenBlasを使用してみたら半分ぐらいの時間で終わりました。
(でも時間かかる。GPU強すぎ)
ってことでGonumでOpenBlasを使う方法をざっと紹介します。
おまけ🍭
インストール方法
go get -u -t gonum.org/v1/gonum/...
go get -d gonum.org/v1/netlib/...
git clone https://github.com/xianyi/OpenBLAS
cd OpenBLAS
make
make install
export LD_LIBRARY_PATH=/opt/OpenBLAS/lib/:$LD_LIBRARY_PATH
CGO_LDFLAGS="-L/opt/OpenBLAS/lib -lopenblas" go run main.go
次にコード
import (
"gonum.org/v1/gonum/blas/blas64"
"gonum.org/v1/netlib/blas/netlib"
)
func init() {
blas64.Use(netlib.Implementation{})
}
ではまた👋