見出し画像

ML 千本ノック: PyTorch で数字を見分ける(MNIST)

この連載では、機械学習フレームワークのサンプルコードを毎回1つずつピックアップして実行していきます。
その過程で得られたノウハウや考え方について、簡潔にまとめていきます。

今回のお題は「PyTorch で数字を見分ける」です。
前回同様 MNIST を CNN による画像の classification(分類問題)で解決しますが、今回は Keras の代わりに PyTorch を使います
コードのオリジナル版はこちらです。
→ https://github.com/pytorch/examples/tree/master/mnist

Preparation & Preprocessing

from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
preprocess_x = transforms.Compose([
    transforms.ToTensor(), transforms.Normalize((0.1307, ), (0.3081, )),
])
additional_args = {“num_workers”: 1, “pin_memory”: True, } if use_cuda else {}
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST(“../data”, train=True, download=True, transform=preprocess_x),
    batch_size=64, shuffle=True, **additional_args
)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST(“../data”, train=False, transform=preprocess_x),
    batch_size=1000, shuffle=False, **additional_args
)

注目すべきポイント:

torch.utils.data.DataLoader() で mini batch 生成を楽にできる
 → バッチに含まれる入力データの次元は NCHW
 → ターゲットはラベル(クラス値)のまま
・ torchvision.datasets.MNIST() の transform 引数と torchvision.transforms.Compose クラス群で preprocessing を楽にできる
 → torchvision.transforms.Normalize() インスタンス引数のマジックナンバーは、training set から算出しておいた平均・分散であろう(未確認)

可読性に関すること:

・ torch のサブモジュールは一行にまとめて import しづらい、かといって個々に import すると import 文だらけになってしまう
 → nn / F / optim 以外のサブモジュールは、これら別名空間からアクセスするのが良い
・ preprocessing はインスタンス化しておく
 → preprocess_x

Modeling

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout2d(0.25)
        self.dropout2 = nn.Dropout2d(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        x = F.log_softmax(x, dim=1)
        return x


model = Net().to(device)

注目すべきポイント:

・ nn.Module の継承と親クラスのコンストラクター呼び出しを忘れずに
・ nn レイヤーのコンストラクター引数には、入力次元を明示しなければならない

可読性に関すること:

・ GPU 使用時と CPU 使用時とでコードが同じに見えるようにしておく

Training

def train(model, device, loader, optimizer):
    model.train()
    total_loss = 0
    correct = 0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        output = model(x)
        loss = F.nll_loss(output, y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        y_pred = output.argmax(dim=1, keepdim=True)
        correct += y_pred.eq(y.view_as(y_pred)).sum().item()
    return {
        “loss”: total_loss / len(loader.dataset),
        “accuracy”: correct / len(loader.dataset),
    }


optimizer = torch.optim.Adadelta(model.parameters(), lr=1.0)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.7)
with tqdm(range(1, 16)) as progress:
    for epoch in progress:
        result = train(model, device, train_loader, optimizer)
        scheduler.step()
        progress.write(str(result))

注目すべきポイント:

・ model.train() によるトレーニングモードへの移行を忘れずに
 → nn.DropOut レイヤーの演算内容を切り替えるために必要
・ Automatic Differentiation 系のフレームワークなので(?)、微分値クリアや backprop などは明示的に呼び出しが必要 
・ 理由は定かではないが adamdelta optimizer
・ さらに learning rate に stepping スケジューラーを導入している
 → "step_size=1" なので learning rate decay と等価
・ F.nll_loss() 実質的には categorical crossentropy で算出している
 → ターゲット y を one-hot 形式に変換しておく必要が無く、ちょこっと楽

Evaluating

def test(model, device, loader):
    model.eval()
    total_loss = 0
    correct = 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            output = model(x)
            total_loss += F.nll_loss(output, y, reduction=“sum”).item()
            y_pred = output.argmax(dim=1, keepdim=True)
            correct += y_pred.eq(y.view_as(y_pred)).sum().item()
    return {
        “loss”: total_loss / len(loader.dataset),
        “accuracy”: correct / len(loader.dataset),
    }


result = test(model, device, test_loader)
print(str(result))
# torch.save(model.state_dict(), "mnist_cnn.pt")

注目すべきポイント:

・ 15 epoch で 'accuracy': 0.9914 となり、とても良い精度に達している
・ model.eval() による評価モードへの移行を忘れずに
 → nn.DropOut レイヤーの演算内容を切り替えるために必要
・ "with torch.no_grad()" により、スコープ内で生成される torch.Tensor が微分値を持たなくなる
 → "requires_grad=False" としてふるまう
 → 演算が軽くなる

F.nll_loss と F.log_softmax についてもう少し

リファレンスを調べたところ、F.nll_loss() のプロトタイプは次のようになっていました。

torch.nn.functional.nll_loss(input, target, weight=None, size_average=None, ignore_index=-100, reduce=None, reduction='mean')

引数 target はラベル値ということになっていました。
引数 input はどうも probs(確率)を期待しているようです。
今回のサンプルコードの neural network model の最終段の activation function を確認すると、確かに F.log_softmax() が指定されています。

なお nll は Negative Log Likelihood の略でした。

類似の loss function に F.cross_entropy() というものを見つけました。

torch.nn.functional.cross_entropy(input, target, weight=None, size_average=None, ignore_index=-100, reduce=None, reduction='mean')

ここで引数 input は logits(ここでは neural network model の最終段の activation function 適用前のテンソルを指します)ということになっていました。
F.cross_entropy() のソースコードを見ると最後の処理が "return nll_loss(log_softmax(input, 1), target, ...)" となっており、F.nll_loss() のラッパーとなっていることが確認できました。

念のため検算もしてみました。
これまでの理解が正しければ、以下で出力されるテンソルの1つ目と4つ目は一致するはずです。

logits = torch.tensor([[0., 0., 1., 0., ], [-10., -1., 1., 10., ]], dtype=torch.float32)
probs = F.log_softmax(logits, dim=1)
target = torch.tensor([2, 3, ], dtype=torch.int64)
print("cross_entropy:")
print(F.cross_entropy(logits, target, reduction="none"))
print(F.cross_entropy(probs, target, reduction="none"))
print("nll_loss:")
print(F.nll_loss(logits, target, reduction="none"))
print(F.nll_loss(probs, target, reduction="none"))

実行結果:

cross_entropy:
tensor([7.4367e-01, 1.4006e-04])
tensor([7.4367e-01, 1.4006e-04])
nll_loss:
tensor([ -1., -10.])
tensor([7.4367e-01, 1.4006e-04])

検算してみて気づきましたが、出力されるテンソルの1つ目と2つ目も一致するのですね。
probs にさらに F.log_softmax() を掛けても、テンソルが変化しないことを今回初めて知りました。

logits,
F.log_softmax(logits, dim=1),
F.log_softmax(F.log_softmax(logits, dim=1), dim=1)

実行結果:

(tensor([[  0.,   0.,   1.,   0.],
        [-10.,  -1.,   1.,  10.]]),
tensor([[-1.7437e+00, -1.7437e+00, -7.4367e-01, -1.7437e+00],
        [-2.0000e+01, -1.1000e+01, -9.0001e+00, -1.4006e-04]]),
tensor([[-1.7437e+00, -1.7437e+00, -7.4367e-01, -1.7437e+00],
        [-2.0000e+01, -1.1000e+01, -9.0001e+00, -1.4006e-04]]))

Appendix: Training Log (GPU 使用時)

  7%|▋         | 1/15 [00:20<04:41, 20.10s/it]{'loss': 0.003231775644204269, 'accuracy': 0.9387}
 13%|█▎        | 2/15 [00:40<04:21, 20.10s/it]{'loss': 0.0012329982460166017, 'accuracy': 0.9767166666666667}
 20%|██        | 3/15 [01:00<04:01, 20.12s/it]{'loss': 0.0009372690077211397, 'accuracy': 0.9828}
 27%|██▋       | 4/15 [01:20<03:40, 20.09s/it]{'loss': 0.0007487517248218258, 'accuracy': 0.98645}
 33%|███▎      | 5/15 [01:40<03:20, 20.08s/it]{'loss': 0.0006731647326339347, 'accuracy': 0.9878166666666667}
 40%|████      | 6/15 [02:00<03:00, 20.06s/it]{'loss': 0.0005873075547783326, 'accuracy': 0.98875}
 47%|████▋     | 7/15 [02:20<02:40, 20.08s/it]{'loss': 0.0005524067785164031, 'accuracy': 0.9893333333333333}
 53%|█████▎    | 8/15 [02:40<02:20, 20.10s/it]{'loss': 0.0004975112415423306, 'accuracy': 0.99025}
 60%|██████    | 9/15 [03:00<02:00, 20.10s/it]{'loss': 0.0005170444033574313, 'accuracy': 0.9903333333333333}
 67%|██████▋   | 10/15 [03:20<01:40, 20.08s/it]{'loss': 0.0004677461209357716, 'accuracy': 0.9908833333333333}
 73%|███████▎  | 11/15 [03:41<01:20, 20.10s/it]{'loss': 0.0004638715137582039, 'accuracy': 0.9916333333333334}
 80%|████████  | 12/15 [04:01<01:00, 20.08s/it]{'loss': 0.0004659055080555845, 'accuracy': 0.9915166666666667}
 87%|████████▋ | 13/15 [04:21<00:40, 20.08s/it]{'loss': 0.00046367862552870064, 'accuracy': 0.9915833333333334}
 93%|█████████▎| 14/15 [04:40<00:20, 20.00s/it]{'loss': 0.0004662856282375287, 'accuracy': 0.9909833333333333}
100%|██████████| 15/15 [05:01<00:00, 20.07s/it]{'loss': 0.0004469044501214133, 'accuracy': 0.9914}


この記事が気に入ったらサポートをしてみませんか?