見出し画像

ML 千本ノック: PyTorch で画像を見分ける(CIFAR10)

この連載では、機械学習フレームワークのサンプルコードを毎回1つずつピックアップして実行していきます。
その過程で得られたノウハウや考え方について、簡潔にまとめていきます。

今回のお題は「PyTorch で画像を見分ける」です。
データセットとして、MNIST の代わりに CIFAR10 を使います。
入力は 32x32 のカラー画像です。
過去の記事と同様に CNN による画像の classification(分類問題)で解決しますので、本質的には同じ問題を解いていることになります。
コードのオリジナル版はこちらです。
https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html#sphx-glr-beginner-blitz-cifar10-tutorial-py

Preparation & Preprocessing

from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
preprocess_x = transforms.Compose([
   transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
train_loader = torch.utils.data.DataLoader(
   datasets.CIFAR10(root="./data", train=True, download=True, transform=preprocess_x),
   batch_size=16, shuffle=True, num_workers=2
)
test_loader = torch.utils.data.DataLoader(
   datasets.CIFAR10(root="./data", train=False, download=True, transform=preprocess_x),
   batch_size=16, shuffle=False, num_workers=2
)
classes = train_loader.dataset.classes

注目すべきポイント:

・ torch.utils.data.DataLoader() 引数の num_workers で data loading に使われる CPU コア(?)の個数を指定できる
 → 引数 worker_init_fn を使えば、CPU コアごとの初期化ハンドラーの登録も可能
・ transforms.Normalize() の2つ目のインスタンス引数(=分散)は適当に決定した値であろう(未確認)
・ 各クラスのラベル文字列は、datasets.CIFAR10() インスタンスもしくは DataLoader() インスタンスから取得できる
 → オリジナルコードでは自前で文字列配列を準備していた

Modeling

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.pool = nn.MaxPool2d(2, 2)
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, len(classes))

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

model = Net().to(device)

注目すべきポイント:

・ max-pooling に F.max_pool2d() ではなく nn.MaxPool2d() を使用している
・ forward() メソッドの計算グラフ内で self.pool インスタンスを2度使用している
・ forward() メソッドの CNN から FCN への変換に torch.flatten() ではなく torch.Tensor.view() を使用している

Training

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
with tqdm(range(1, 16)) as progress:
    for epoch in progress:
        model.train()
        running_loss = []
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            outputs = model(x)
            loss = criterion(outputs, y)
            loss.backward()
            optimizer.step()
            running_loss.append(loss.item())
        progress.write(str({"loss": sum(running_loss[-100:], 0.0) / len(running_loss[-100:]), }))
print("Finished Training")
# torch.save(model.state_dict(), PATH)

注目すべきポイント:

・ loss function は分類問題で定番の cross entropy
・ model.train() によるトレーニングモードへの移行を忘れずに
 → nn.Dropout2d() を使っていないせいか、オリジナルコードでは抜けていた
・ optimizer には momentum ありの SGD が指定されている
 → "momentum=0.9" は一般によく見られる指定
 → Training Log を見ると、"lr=0.001" は低すぎな印象を受ける(未検証)

可読性に関すること:

・ loss function を criterion という変数で表現している
・ リストの平均(実数)を取るのに、"sum(vec, 0.0) / len(vec)" というイディオムを使用した
 → 分子の 0.0 は演算結果を実数にするため

Evaluating

# model = Net().to(device)
# model.load_state_dict(torch.load(PATH))
model.eval()
correct = 0.0
with torch.no_grad():
    for x, y in test_loader:
        x, y = x.to(device), y.to(device)
        outputs = model(x)
        _, y_pred = torch.max(outputs.data, 1)
        correct += (y_pred == y).sum().item()
print("Test Accuracy: {}".format(correct / len(test_loader.dataset)))

注目すべきポイント:

・ 15 epoch で "Test Accuracy": 0.657 となり、それなりの精度に達している
・ model.eval() による評価モードへの移行を忘れずに
 → オリジナルコードでは抜けていた

Evaluating、追加

オリジナルコードにならってクラス別の accuracy を算出してみると、以下のようにかなりのバラつきがあった。

model.eval()
class_total, class_correct = [0.0] * len(classes), [0.0] * len(classes)
with torch.no_grad():
    for x, y in test_loader:
        x, y = x.to(device), y.to(device)
        outputs = model(x)
        _, y_pred = torch.max(outputs, 1)
        c = (y_pred == y).squeeze()
        for i, l in zip(range(4), y):
            class_total[l] += 1
            class_correct[l] += c[i].item()
for i in range(len(classes)):
    print("Accuracy of class {:10s} : {:.4f}".format(classes[i], class_correct[i] / class_total[i]))
Accuracy of class airplane   : 0.6784
Accuracy of class automobile : 0.8097
Accuracy of class bird       : 0.4640
Accuracy of class cat        : 0.3673
Accuracy of class deer       : 0.5741
Accuracy of class dog        : 0.6875
Accuracy of class frog       : 0.7851
Accuracy of class horse      : 0.6589
Accuracy of class ship       : 0.7814
Accuracy of class truck      : 0.7769

F.max_pool2d() と nn.MaxPool2d() の違いについてもう少し

F.max_pool2d() は torch.nn.functional カテゴリーの関数です。
torch.nn.functional カテゴリーの関数は、入力と出力が torch.Tensor インスタンスとなっているのが共通点です。
そのため基本的には nn.Module() の forward() メソッド内でよく使われます。

nn.MaxPool2d() は torch.nn カテゴリーのクラスです。
torch.nn カテゴリーのクラスはメンバー変数にインスタンス引数を格納し、それらを forward() メソッド内で参照しているのが共通点です。
そのため基本的には nn.Module() のコンストラクター内でよく使われます。

参考までに nn.MaxPool2d() のソースコードを引用します↓

class MaxPool2d(_MaxPoolNd):
    def forward(self, input):
        return F.max_pool2d(input, self.kernel_size, self.stride,
                            self.padding, self.dilation, self.ceil_mode,
                            self.return_indices)

torch.nn カテゴリーのクラスには学習パラメーターを持つものと持たないものがありますが、ここで述べた nn.MaxPool2d() は後者です。
学習パラメーターを持たない torch.nn クラスの存在意義は、nn.Sequential() のような計算グラフの記述、つまり forward() メソッドの記述を自動化する余地が生まれる点でしょう。
forward() メソッドの引数を torch.Tensor インスタンスのみとして それ以外の引数を不要とするために、torch.nn.functional をラッピングしたクラスが準備されているんだと思います。

torch.Tensor.view() と nn.Flatten() の違いも基本的には同様です。
"torch.Tensor.view(..., end_dim=-1)" をラッピングしたものが nn.Flatten() です。

もう2つ同様の関数として torch.flatten() と torch.Tensor.flatten() がありますが、特に前者の存在意義についてはよく分かりません。(情報求む)

Appendix: Training Log

  7%|▋         | 1/15 [00:33<07:46, 33.35s/it]{'loss': 1.7024931907653809}
 13%|█▎        | 2/15 [01:06<07:14, 33.42s/it]{'loss': 1.4613035637140275}
 20%|██        | 3/15 [01:41<06:43, 33.62s/it]{'loss': 1.344044024348259}
 27%|██▋       | 4/15 [02:14<06:10, 33.71s/it]{'loss': 1.2590252596139908}
 33%|███▎      | 5/15 [02:48<05:37, 33.75s/it]{'loss': 1.2261212736368179}
 40%|████      | 6/15 [03:22<05:03, 33.70s/it]{'loss': 1.1608078479766846}
 47%|████▋     | 7/15 [03:56<04:29, 33.69s/it]{'loss': 1.0597871035337447}
 53%|█████▎    | 8/15 [04:29<03:55, 33.67s/it]{'loss': 1.0947697323560714}
 60%|██████    | 9/15 [05:03<03:21, 33.63s/it]{'loss': 0.960875992178917}
 67%|██████▋   | 10/15 [05:36<02:48, 33.66s/it]{'loss': 0.8886493289470673}
 73%|███████▎  | 11/15 [06:10<02:14, 33.64s/it]{'loss': 0.9344993713498115}
 80%|████████  | 12/15 [06:44<01:40, 33.62s/it]{'loss': 0.890320642888546}
 87%|████████▋ | 13/15 [07:18<01:07, 33.74s/it]{'loss': 0.8351841151714325}
 93%|█████████▎| 14/15 [07:51<00:33, 33.66s/it]{'loss': 0.8004632332921028}
100%|██████████| 15/15 [08:24<00:00, 33.66s/it]{'loss': 0.7912632083892822}
Finished Training


この記事が気に入ったらサポートをしてみませんか?