PyTorchで自作データセット作成

PyTorchで自前のデータの読み込み,データローダーの作成までを行ったのでまとめる.

環境

・Python 3.9.4
・PyTorch 1.9.1

用意したデータ

・読み込みたい画像のファイル名をすべて記載したテキストファイル 'filename.txt'

Aimage1.png
Bimage1.png
Cimage1.png
Dimage1.png
Aimage2.png
Bimage2.png
Cimage2.png
Dimage2.png
Aimage3.png
...

​・対応する画像ファイル群

作成したいデータセット

普段,我々が目にする画像は1チャンネル画像(グレースケール画像)もしくは,3チャンネル画像(RGBカラー画像)であろう.

今回作成したいデータセットは,
・入力データ : 4チャンネル画像 x 3
・教師データ : 1チャンネル画像 x 1
が1セットとなるものである.

基本的な作成手順に関しては,その他のデータにも当てはまるだろう.

目標となるデータの形状は次である.
input data [ size = 242, slice = 3, channel = 4, width = 384, height = 384 ]
teacher data [ size = 242, width = 384, height = 384 ]

ライブラリのインポート

import torch
import torchvision
import numpy as np
from PIL import Image

画像ファイルの読み込み

全画像ファイルを読み込み,numpy.ndarray型のデータとして

LOAD_DIR = "./data/FourChannel/"
LOAD_FILE = "filename.txt"
SLICE_NUM = 3
CHANNEL = 4
IMAGE_SIZE = 384

#open file
file = open(LOAD_DIR + LOAD_FILE, "r", encoding = "utf_8")

#read first line
line = file.readline()
#a number of data
N = int(line)

#data[num][slice][channel][w][h]
data = np.zeros((N, SLICE_NUM, CHANNEL, IMAGE_SIZE, IMAGE_SIZE))
#masks[num][w][h]
masks = np.zeros((N, IMAGE_SIZE, IMAGE_SIZE))

#read second line and delete '\n'
line = file.readline()
line = line[0:-1]

for i in range(N):

   for j in range(SLICE_NUM):

       for k in range(CHANNEL):

           #read image
           data[i][j][k] = np.array(Image.open(LOAD_DIR + line))

           #output log
           if LOG:
               print("read " + line)

           #read line and delete '\n'
           line = file.readline()
           line = line[0:-1]

   while 'ROI' in line:

       if 'ROI_1' in line:
           #read image
           masks[i] = np.array(Image.open(LOAD_DIR + line))
               
           #output log
           if LOG:
               print("read " + line)

       else :
           #output log
           if LOG:
               print("skip " + line)

       #read line and delete '\n'
       line = file.readline()
       line = line[0:-1]   

   if LOG:
       print("("+str(i+1)+"/"+str(N)+")")

file.close()   

読み込んだデータの確認

print("name\tdata type\t\tshape")
print("data\t" + str(type(data)) + "\t" + str(data.shape))
print("masks\t" + str(type(masks)) + "\t" + str(masks.shape))

# name	data type		shape
# data	<class 'torch.Tensor'>	torch.Size([32, 4, 384, 384])
# masks	<class 'numpy.ndarray'>	(242, 384, 384)

データセットの作成

訓練用とテスト用のデータセットのクラスを定義

#self made transform class
class toTensor(object):
   def __init__(self):
       pass
   def __call__(self, x):
       data = x / 255
       return torch.tensor(data, dtype = torch.float)

   
class TrainDataset(torch.utils.data.Dataset):
   def __init__(self, train_data_ratio = 0.8):
       
       self.transform = toTensor() #self made transform class
       self.train_data = data[:int(N*train_data_ratio), int((SLICE_NUM-1)/2)]
       self.train_masks = masks[:int(N*train_data_ratio)]
       
       self.data_num = len(self.train_data)
       
   def __len__(self):
       return self.data_num
   
   def __getitem__(self, idx):
       out_data = self.train_data[idx]
       out_mask = self.train_masks[idx]
       
       out_data = self.transform(out_data)
       out_mask = self.transform(out_mask)

       return out_data, out_mask
   
   
class TestDataset(torch.utils.data.Dataset):
   def __init__(self, train_data_ratio = 0.2):
       self.transform = toTensor() #self made transform class
       self.test_data = data[int(N*train_data_ratio):, int((SLICE_NUM-1)/2)]
       self.test_masks = masks[int(N*train_data_ratio):]

       self.data_num = len(self.test_data)

   def __len__(self):
       return self.data_num

   def __getitem__(self, idx):
       out_data = self.test_data[idx]
       out_mask = self.test_masks[idx]

       out_data = self.transform(out_data)
       out_mask = self.transform(out_mask)

       return out_data, out_mask

データセットのインスタンスの生成と確認

TRAIN_DATA_RATIO = 0.795

train_dataset = TrainDataset(TRAIN_DATA_RATIO)
test_dataset = TestDataset(TRAIN_DATA_RATIO)

print("name\t\tdata type\t\t\tlength\tinput data shape\t\output data shape\tdtype")
print("train_dataset\t"+str(type(train_dataset))+"\t"+str(len(train_dataset))+"\t"+str(train_dataset[0][0].shape)+"\t"+str(train_dataset[0][1].shape)+"\t"+str(train_dataset[0][0].dtype))
print("test_dataset\t"+str(type(test_dataset))+"\t"+str(len(test_dataset))+"\t"+str(test_dataset[0][0].shape)+"\t"+str(test_dataset[0][1].shape)+"\t"+str(test_dataset[0][0].dtype))
# name		    data type			            length	input data shape	        output data shape	    dtype
# train_dataset	<class '__main__.TrainDataset'>	192	    torch.Size([4, 384, 384])	torch.Size([384, 384])	torch.float32
# test_dataset	<class '__main__.TestDataset'>	50	    torch.Size([4, 384, 384])	torch.Size([384, 384])	torch.float32

データローダーの作成

乱数のシード値を固定し,再現性を保証している.

torch.manual_seed(0)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size = 32, shuffle = True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size = 25, shuffle = True)

この記事が気に入ったらサポートをしてみませんか?