PyTorchで自作データセット作成
PyTorchで自前のデータの読み込み,データローダーの作成までを行ったのでまとめる.
環境
・Python 3.9.4
・PyTorch 1.9.1
用意したデータ
・読み込みたい画像のファイル名をすべて記載したテキストファイル 'filename.txt'
Aimage1.png
Bimage1.png
Cimage1.png
Dimage1.png
Aimage2.png
Bimage2.png
Cimage2.png
Dimage2.png
Aimage3.png
...
・対応する画像ファイル群
作成したいデータセット
普段,我々が目にする画像は1チャンネル画像(グレースケール画像)もしくは,3チャンネル画像(RGBカラー画像)であろう.
今回作成したいデータセットは,
・入力データ : 4チャンネル画像 x 3
・教師データ : 1チャンネル画像 x 1
が1セットとなるものである.
基本的な作成手順に関しては,その他のデータにも当てはまるだろう.
目標となるデータの形状は次である.
input data [ size = 242, slice = 3, channel = 4, width = 384, height = 384 ]
teacher data [ size = 242, width = 384, height = 384 ]
ライブラリのインポート
import torch
import torchvision
import numpy as np
from PIL import Image
画像ファイルの読み込み
全画像ファイルを読み込み,numpy.ndarray型のデータとして
LOAD_DIR = "./data/FourChannel/"
LOAD_FILE = "filename.txt"
SLICE_NUM = 3
CHANNEL = 4
IMAGE_SIZE = 384
#open file
file = open(LOAD_DIR + LOAD_FILE, "r", encoding = "utf_8")
#read first line
line = file.readline()
#a number of data
N = int(line)
#data[num][slice][channel][w][h]
data = np.zeros((N, SLICE_NUM, CHANNEL, IMAGE_SIZE, IMAGE_SIZE))
#masks[num][w][h]
masks = np.zeros((N, IMAGE_SIZE, IMAGE_SIZE))
#read second line and delete '\n'
line = file.readline()
line = line[0:-1]
for i in range(N):
for j in range(SLICE_NUM):
for k in range(CHANNEL):
#read image
data[i][j][k] = np.array(Image.open(LOAD_DIR + line))
#output log
if LOG:
print("read " + line)
#read line and delete '\n'
line = file.readline()
line = line[0:-1]
while 'ROI' in line:
if 'ROI_1' in line:
#read image
masks[i] = np.array(Image.open(LOAD_DIR + line))
#output log
if LOG:
print("read " + line)
else :
#output log
if LOG:
print("skip " + line)
#read line and delete '\n'
line = file.readline()
line = line[0:-1]
if LOG:
print("("+str(i+1)+"/"+str(N)+")")
file.close()
読み込んだデータの確認
print("name\tdata type\t\tshape")
print("data\t" + str(type(data)) + "\t" + str(data.shape))
print("masks\t" + str(type(masks)) + "\t" + str(masks.shape))
# name data type shape
# data <class 'torch.Tensor'> torch.Size([32, 4, 384, 384])
# masks <class 'numpy.ndarray'> (242, 384, 384)
データセットの作成
訓練用とテスト用のデータセットのクラスを定義
#self made transform class
class toTensor(object):
def __init__(self):
pass
def __call__(self, x):
data = x / 255
return torch.tensor(data, dtype = torch.float)
class TrainDataset(torch.utils.data.Dataset):
def __init__(self, train_data_ratio = 0.8):
self.transform = toTensor() #self made transform class
self.train_data = data[:int(N*train_data_ratio), int((SLICE_NUM-1)/2)]
self.train_masks = masks[:int(N*train_data_ratio)]
self.data_num = len(self.train_data)
def __len__(self):
return self.data_num
def __getitem__(self, idx):
out_data = self.train_data[idx]
out_mask = self.train_masks[idx]
out_data = self.transform(out_data)
out_mask = self.transform(out_mask)
return out_data, out_mask
class TestDataset(torch.utils.data.Dataset):
def __init__(self, train_data_ratio = 0.2):
self.transform = toTensor() #self made transform class
self.test_data = data[int(N*train_data_ratio):, int((SLICE_NUM-1)/2)]
self.test_masks = masks[int(N*train_data_ratio):]
self.data_num = len(self.test_data)
def __len__(self):
return self.data_num
def __getitem__(self, idx):
out_data = self.test_data[idx]
out_mask = self.test_masks[idx]
out_data = self.transform(out_data)
out_mask = self.transform(out_mask)
return out_data, out_mask
データセットのインスタンスの生成と確認
TRAIN_DATA_RATIO = 0.795
train_dataset = TrainDataset(TRAIN_DATA_RATIO)
test_dataset = TestDataset(TRAIN_DATA_RATIO)
print("name\t\tdata type\t\t\tlength\tinput data shape\t\output data shape\tdtype")
print("train_dataset\t"+str(type(train_dataset))+"\t"+str(len(train_dataset))+"\t"+str(train_dataset[0][0].shape)+"\t"+str(train_dataset[0][1].shape)+"\t"+str(train_dataset[0][0].dtype))
print("test_dataset\t"+str(type(test_dataset))+"\t"+str(len(test_dataset))+"\t"+str(test_dataset[0][0].shape)+"\t"+str(test_dataset[0][1].shape)+"\t"+str(test_dataset[0][0].dtype))
# name data type length input data shape output data shape dtype
# train_dataset <class '__main__.TrainDataset'> 192 torch.Size([4, 384, 384]) torch.Size([384, 384]) torch.float32
# test_dataset <class '__main__.TestDataset'> 50 torch.Size([4, 384, 384]) torch.Size([384, 384]) torch.float32
データローダーの作成
乱数のシード値を固定し,再現性を保証している.
torch.manual_seed(0)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size = 32, shuffle = True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size = 25, shuffle = True)
この記事が気に入ったらサポートをしてみませんか?