【Go・Deep Learning その3.1】 ゼロから作るdeep learning 3を復習 ー学習の重みを保存してみる。

前回まで

今回

今回は主にGolangでどのようにして構造体を保存するのかをやっていきます。
あまりディープラーニングとは関係ないです。すみません。🙇‍♂️

最初に調べてみたら出てきた記事🙇‍♂️

ですが、この記事で紹介されている"encoding/binary"パッケージでは、
複雑な構造体を保存するのは無理なようです。

そこで、この記事でも紹介されている
"encoding/gob"パッケージを使って保存して行こうと思います。

愚直にソースコード


func SaveBinary(filename string, value interface{}) error {
	f, err := os.Create(filename)
	if err != nil {
		return fmt.Errorf("filed to create file: %w", err)
	}
	defer f.Close()

	if err := gob.NewEncoder(f).Encode(value); err != nil {
		return fmt.Errorf("cannnot encode matrix map%w", err)
	}
	return nil
}

func LoadBinary(filename string, fileValue interface{}) error {
	f, err := os.Open(filename)
	if err != nil {
		return fmt.Errorf("failed to open file: %w", err)
	}
	defer f.Close()

	if err := gob.NewDecoder(f).Decode(fileValue); err != nil {
		return fmt.Errorf("cannot decode matrix: %w", err)
	}
	return nil
}

使い方

filename := "test"

// 保存
st := []Sample{}
err := SaveBinary(filename, st)

// 読み込み
var result Sample
err := LoadBinary(filename, &result)

実際に学習に組み込んでみる✒️

const (
	lr         = 0.2
	maxIter    = 10000
	hiddenSize = 10
	fileName   = "test.w"
)

x := dz.NewVariable(core.NewRand(core.Shape{R: 100, C: 1}))
a1 := x.Data().CopyApply(func(f float64) float64 { return math.Sin(2 * math.Pi * f) })
a2 := core.NewRand(core.Shape{R: 100, C: 1})
y := fn.Add(dz.NewVariable(a1), dz.NewVariable(a2))

model := models.NewMLP([]int{hiddenSize, 2})
optim := optimizers.NewSGD(dz.Lr(lr)).Setup(model)

// ここ
if err := model.LoadWeights(fileName); err != nil {
	log.Warn(err)
}

for i := 0; i < maxIter; i++ {
	yPred := model.Apply(x).First()
	loss := fn.MeanSquaredError(y, yPred)

	model.ClearGrads()
	loss.Backward(dz.RetainGrad(true))

	optim.Update()

	if i%100 == 0 {
		fmt.Println(loss)
	}
}

// ここ
if err := model.SaveWeights(fileName); err != nil {
	panic(err)
}

一回目の損失関数の推移🏃‍♂️

5から0.185へと推移してますね。

二回目の損失関数の推移👬

二回目は、0.165から0.156と最初から低いようです。

とりあえず、最初っから値が小さいので、
読み込みていそう?


まとめ✅

今回は機械学習とはあまり関係ない記事でした🙇‍♂️
numpyでは簡単に行列を保存する機能があります。
Go言語ではバイナリで保存することは決して難しいことではなさそうです。

最近あまり勉強に時間を作れていないので、
残りはサクッと終わらせたいと思います。
何も有益な記事ではないですが、怒らないでくださいorz


この記事が気に入ったらサポートをしてみませんか?