PyTorchでDeep Learning実装。- 準備
実装してみます。有名なMNISTを使います。まずライブラリ。
import torch
import torch.nn.functional as f
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
そしてネットワークを作ります。ここでネットワークの作り方をご紹介。
公式サイトです。
"torch.nn.Module"は前回、基本的なことは記事にしましたが今回はそのほかの方法で"Sequential"のご紹介です。以下公式サイトより。
# Example of using Sequential
model = nn.Sequential(
nn.Conv2d(1,20,5),
nn.ReLU(),
nn.Conv2d(20,64,5),
nn.ReLU()
)
# Example of using Sequential with OrderedDict
model = nn.Sequential(OrderedDict([
('conv1', nn.Conv2d(1,20,5)),
('relu1', nn.ReLU()),
('conv2', nn.Conv2d(20,64,5)),
('relu2', nn.ReLU())
]))
単純なものを簡単に実装できますね。
さあ、仕切り直して今回は"torch.nn.Module"を使って実装していきます。
class MyNet(torch.nn.Module):
def __init__(self):
super(MyNet, self).__init__()
self.fc1 = torch.nn.Linear(28*28, 1000)
self.fc2 = torch.nn.Linear(1000, 10)
def forward(self, x):
x = self.fc1(x)
x = torch.sigmoid(x)
x = self.fc2(x)
return f.log_softmax(x, dim=1)
ネットワーク構成はシンプルに
入力層(784) - 中間層(1000)
self.fc1 = torch.nn.Linear(28*28, 1000)
中間層(1000) - 出力層(10)」の3層構造とします。
self.fc2 = torch.nn.Linear(1000, 10)
中間層の活性化関数に「シグモイド(sigmoid)関数」
x = torch.sigmoid(x)
出力は確率にしたいので「ソフトマックス(softmax)関数」
return f.log_softmax(x, dim=1)
次にデータを用意しないといけないので、今回はデータセット(MNIST)を読み込みます。データローダーという形でデータを取扱します。
from torchvision import datasets, transforms
でモジュールを読み込むことで使えるようになります。詳しくは
def load_MNIST(batch=128, intensity=1.0):
train_loader = torch.utils.data.DataLoader(
datasets.MNIST('./data',train=True,download=True,transform=transforms.Compose([
transforms.ToTensor(),transforms.Lambda(lambda x: x * intensity)
])),batch_size=batch,shuffle=True)
test_loader = torch.utils.data.DataLoader(
datasets.MNIST('./data',train=False,transform=transforms.Compose([
transforms.ToTensor(), transforms.Lambda(lambda x: x * intensity)
])),batch_size=batch,shuffle=True)
return {'train': train_loader, 'test': test_loader}
ここでちょっとMNISTについて寄り道を、まずdatasets.MNISTでデータを取得します。
mnist_data = datasets.MNIST('~/tmp/mnist', train=True, download=True, transform=transforms.ToTensor())
data_loader = DataLoader(mnist_data,batch_size=4, shuffle=False)
可視化してみます。
data_iter = iter(data_loader)
images, labels = data_iter.next()
npimg = images[0].numpy()
npimg = npimg.reshape((28, 28))
plt.imshow(npimg, cmap='gray')
print('Label:', labels[0])
表示されました。文字認識用のデータが確かに使うことができるようになっています。
ネットワークとデータが用意ができました。
データセットについて参考になるサイトをもう一つ。
この記事が気に入ったらサポートをしてみませんか?