【Go・Deep Learning その③】 ゼロから作るdeep learning 3を復習 ーMNISTを学習する

これまで

今回

今回はざっくり全体像を説明しながら、MNIST学習していこうと思います。

MNISTをダウンロード⬇︎

とりあえずMNISTデータを取得しちゃいますかぁ
下の項目をまずダウンロードするコード書く📝

  •  http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz

  • http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz

  • http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz

  • http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz

const (
	TRAIN_URL      = "http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz"
	LABEL_URL      = "http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz"
	TEST_TRAIN_URL = "http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz"
	TEST_LABEL_URL = "http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz"

	LOCAL_DATA_PATH        = "data"
	LOCAL_TRAIN_PATH       = LOCAL_DATA_PATH + "/train-images-idx3-ubyte.gz"
	LOCAL_LABELS_PATH      = LOCAL_DATA_PATH + "/train-labels-idx1-ubyte.gz"
	LOCAL_TEST_TRAIN_PATH  = LOCAL_DATA_PATH + "/t10k-images-idx3-ubyte.gz"
	LOCAL_TEST_LABELS_PATH = LOCAL_DATA_PATH + "/t10k-labels-idx1-ubyte.gz"
)

func downloadMnist() error {
	urls := []string{TRAIN_URL, LABEL_URL, TEST_TRAIN_URL, TEST_LABEL_URL}
	localPaths := []string{LOCAL_TRAIN_PATH, LOCAL_LABELS_PATH, LOCAL_TEST_TRAIN_PATH, LOCAL_TEST_LABELS_PATH}
	if !util.Exists(LOCAL_TRAIN_PATH) ||
		!util.Exists(LOCAL_LABELS_PATH) ||
		!util.Exists(LOCAL_TEST_TRAIN_PATH) ||
		!util.Exists(LOCAL_TEST_LABELS_PATH) {
		if !util.Exists(LOCAL_DATA_PATH) {
			if err := os.Mkdir(LOCAL_DATA_PATH, os.ModePerm); err != nil {
				return err
			}
		}

		for i := 0; i < len(urls); i++ {
			log.Info("Download", urls[i], "...")

			response, err := http.Get(urls[i])
			if err != nil {
				return err
			}

			body, err := ioutil.ReadAll(response.Body)
			if err != nil {
				return err
			}

			log.Info("Save", localPaths[i])
			if err := ioutil.WriteFile(localPaths[i], body, os.ModePerm); err != nil {
				return err
			}

		}
	}
	return nil
}


GoMNISTというライブラリ📚

mnistのデータを簡単に使いたい!😇
GoMNISTというのがGitHubにありました。
MNISTデータを学習データとラベルデータに読み込み・変換をおこなってくれます。
使うことをお勧めいたします。🙇‍♂️🙇‍♂️🙇‍♂️

ざっくりしたMainコード

lossAcc := dgraph.NewLossAcc("train", infragraph.New())
testLossAcc := dgraph.NewLossAcc("test", infragraph.New())

maxEpoch := 5 // 繰り返す回数
batchSize := 100 // バッチサイズ
hiddenSize := 1000 // 隠れ層

gray := func(m core.Matrix) core.Matrix { return m.CopyDivFloat(255) }
trainSet := datasets.NewMnist(dz.Train(true), dz.TransformData(gray)) // MNIST学習用
testSet := datasets.NewMnist(dz.Train(false), dz.TransformData(gray)) // MNISTテスト用
trainLoader := dz.NewDataLoader(trainSet, batchSize)
testloader := dz.NewDataLoader(testSet, batchSize, dz.DShuffle(false))

// 多重パーセプトロン 28*28 -> 1000 -> 1000 -> 10
model := models.NewMLP([]int{hiddenSize, hiddenSize, 10}, models.ActivationFunc(func(v ...dz.Variable) dz.Variables { return fn.NewRelu().Apply(v...) }))

// 最適化手法(Adam)
optim := optimizers.NewAdam().Setup(model)

for i := 0; i < maxEpoch; i++ {
    sumLoss, sumAcc := 0.0.
    bar := pb.StartNew(trainLoader.Len())
    for trainLoader.Next() {
        bar.Increment()
        x, t := trainLoader.Read()           // データの読み込み
        y := model.Apply(x).First()          // モデルの実行
        loss := fn.SoftmaxCrossEntropy(y, t) // 正解との誤差
        acc := fn.Accuacy(y, t)              // 正解率

        model.ClearGrads()
        loss.Backward(dz.RetainGrad(true)) // 微分微分微分
        optim.Update()                     // 重みの更新

        sumLoss += loss.Data().At(00) * float64(t.Data().Len())
        sumAcc += acc.Data().At(00) * float64(t.Data().Len())
    }

    bar.Finish()
    lossAcc.Add(sumLoss/float64(trainSet.Len()), sumAcc/float64(trainSet.Len()))
    fmt.Println("epoch", i+1" train: ",
        "loss", sumLoss/float64(trainSet.Len()),
        "accuracy", sumAcc/float64(trainSet.Len()))
    sumLoss, sumAcc = 00

    // ここからテスト用
    core.NoGrad(func() error {
        for testloader.Next() {
            x, t := testloader.Read()
            y := model.Apply(x).First()
            loss := fn.SoftmaxCrossEntropy(y, t)
            acc := fn.Accuacy(y, t)
            sumLoss += loss.Data().At(00) * float64(t.Data().Len())
            sumAcc += acc.Data().At(00) * float64(t.Data().Len())
        }
        testLossAcc.Add(sumLoss/float64(testSet.Len()), sumAcc/float64(testSet.Len()))

        fmt.Println("epoch", i+1" test: ",
            "loss", sumLoss/float64(testSet.Len()),
            "accuracy", sumAcc/float64(testSet.Len()))
        return nil
    })
}

if err := lossAcc.Plot(); err != nil {
    panic(err)
}
if err := testLossAcc.Plot(); err != nil {
    panic(err)
}

よくわからない解説❓

サクッとコードを理解してみます。

  1. モデルを用意

    • 最初の重みは正規分布に従うランダムな値🔢

    • ちなみに重みの行列はインプットされる行列に形が依存する

    • 今回のモデルのパーセプトロン: 28*28* -> 1000 -> 1000 -> 10 

      1. (row: N, col: 784)  ✖️ (row: 784, col: 1000) + bias

      2. (row: N, col: 1000) ✖️ (row: 1000, col: 1000) + bias

      3. (row: N, col: 1000) ✖️(row: 1000, col: 10) + bias

      4. (row: N, col: 10)

  2. 最適化手法を定義する!?

  3. バッチサイズに沿ってデータを切り取る

  4. モデルの実行(「1.モデルを用意」のやつ)

  5. 結果からlossを求める(答えからのズレ)

  6. lossから関数達を通して微分する

  7. 最適化手法(Adam)でlossの微分結果を元に重みを更新

大体こんなことを繰り返して正解率を高め、損失を小さくしていく


モデルの可視化

AccとLossとバグじゃね?🐛🐛🐛🐛🐛🐛🐛🐛🐛

Train
左: Acc: 0.9757166666666667
右: Loss: 0.347411994067936

Test
左: Acc: 0.9666
右: Loss:  0.5619361522344886

lossが大きい気もします。
Accも低い気がします。
おかしいですね。。。
これはコード見直して何がダメか調べないとダメですかね?
頑張ります。

ソースコード

まとめ✅

今回はざっくりではありますが、MNISTの学習をやりました。🙇‍♂️
関数-> 損失->微分->最適化(更新)のイメージをしっかりと持つと迷子にならないと思いますね。。。。。。。。。。。。。。。。。
次回はCNNあたりになりそうです。
次こそいい感じの記事を、書きたいのに書きたいのにこうな

今回の復習にあたって困った事は、
Gonumが予想よりも遅かったことです。
(文句言うな俺。ごめんなさい。🙇‍♂️)

そこでOpenBlasを使用してみたら半分ぐらいの時間で終わりました。
(でも時間かかる。GPU強すぎ)
ってことでGonumでOpenBlasを使う方法をざっと紹介します。

おまけ🍭

インストール方法

参考

go get -u -t gonum.org/v1/gonum/...
go get -d gonum.org/v1/netlib/...
git clone https://github.com/xianyi/OpenBLAS
cd OpenBLAS
make
make install
export LD_LIBRARY_PATH=/opt/OpenBLAS/lib/:$LD_LIBRARY_PATH
CGO_LDFLAGS="-L/opt/OpenBLAS/lib -lopenblas" go run main.go

次にコード

参考

import (
	"gonum.org/v1/gonum/blas/blas64"
	"gonum.org/v1/netlib/blas/netlib"
)

func init() {
	blas64.Use(netlib.Implementation{})
}

ではまた👋


この記事が気に入ったらサポートをしてみませんか?