見出し画像

機械学習用Template準備-3(datamodule)

はじめに 

 前回前々回でHydraの基本的な使い方と、全体像に関して説明しました。今回は、Pytorch-lightningを使用しdatamoduleを構築していこうと思います。Jetson Xavier NXにPytorch-lightningをインストールするのは少しコツが必要なのでインストールがまだの方はこちらをご参照ください。

基本構成

 今回はCIFA10を使用するdatamoduleを設計します。MNISTでも良いですが、最近MNISTのデーターセットのサーバーが変わったらしくDownload時に頻繁にERRORが発生するため今回は安定しているCIFA10を選びました。また、DVC等を使用し、学習Dataを別管理する方法もありますがHydraが持つ機能を活用すれば似たような機能を実装できるので今回は使用していません。すでにCIFA10をダウンロード済みの方はそちらをお使いください。

 階層構造は以下となります。 

./src/datamodule/dataset.py

 同様に、hyperparameter管理用の.yamlの階層も示します。

./config/datamodule/datamodule.yaml

 Pytorch-lightningとHydraを組み合わせることによりdatamodule部分を分離し管理・運用ができるようになるので非常に便利です。dataが存在するフォルダー位置をhyperparameterとすることで、dataの場所も別で管理できGitでの管理も容易にできます。
 dataset.pyでdatamoduleの本体を実装し、datamodule実装に必要なhyperparamterをyamlで管理するという流れになります。

datamodule.yamlの実装

 実際の中身を見ていきましょう。Folder構成等は適宜変更してください。

# @package datamodule
_target_: src.datamodule.dataset.Cifa10DataModule
data_dir: /home/jetson/project/ml/datasets/CIFA10
download: True
batch_size: 64

 1行目はhydraの宣言分で、packageとしてこのdatamoduleを使用する事を宣言しています。この宣言が無いと"."による参照ができなくなるため必要です。
 2行目はinstanceとして読み込むclass情報を宣言しています。後ほど示す。src内で実装するclass Cifa10dataModuleの名前が書かれています。基本構成のところで示したFolderの階層が”.”により接続されています。
 3行目以降はhyperparameterとなっています。CIFA10のdatasetをダウンロードし保存しておく場所、downloadを有効かするかのOption、及びbatch_sizeが書かれております。

dataset.pyの実装

 Pytorch-lightningを使用することでdataset部分を個別にかつ、Simpleに実装することができます。細かいところは公式のサンプル等をご参照ください。ここでは大枠部分を記載します。最低限の実装なので他の処理が必要な場合は適宜追加してください。

import pytorch_lightning as pl
import multiprocessing


from torchvision.datasets import CIFAR10
from torchvision import transforms


from torch.utils.data import random_split, DataLoader

class Cifa10DataModule(pl.LightningDataModule):
   def __init__(self, data_dir: str, download: bool, batch_size: int):
       super().__init__()
       self._data_dir: str = data_dir
       self._download: bool = download
       self._batch_size: int = batch_size
       # get an available cup number
       self._num_of_cpus: int = multiprocessing.cpu_count()
       # transform
       self._transform = transforms.Compose([
           transforms.ToTensor(),
       ])
       
   def prepare_data(self):
       CIFAR10(self._data_dir, train=True, download=self._download)
       CIFAR10(self._data_dir, train=False, download=self._download)
   def setup(self, stage=None):
       if stage=='fit' or stage is None:
           train_ds = CIFAR10(self._data_dir, train=True, download=False, transform=self._transform)
           self._train_ds, self._val_ds = random_split(train_ds, [45000, 5000])
       elif stage=='test' or stage is None:
           self._test_ds = CIFAR10(self._data_dir, train=False, download=False, transform=self._transform)
   
   def train_dataloader(self):
       return DataLoader(self._train_ds, self._batch_size, True, num_workers=self._num_of_cpus)
   def val_dataloader(self):
       return DataLoader(self._val_ds, self._batch_size, False, num_workers=self._num_of_cpus)
   def test_dataloader(self):
       return DataLoader(self._test_ds, self._batch_size, False, num_workers=self._num_of_cpus)

 非常にシンプルな実装になっております。multiprocessingを使用し、使用可能なCPU数を取得しDataloaderのnum_workersに渡しています。hyperparameterで定義しても良いですが、datamodule以外ではあまり使用しないのでここにまとめています。
 この実装で重要な点は__init__の引数が前述のyaml内で書かれているhyperparameter名と一致している必要があります。一致していない場合はerrorで止まります。

まとめ

 今夏はHydraとPytorch-lightningを組み合わせてdatamoduleを実装しました。pytorch-lightningを使用する事で非常にシンプルでかつ管理しやすい書き方ができます。


いいなと思ったら応援しよう!

USAEng
アメリカSilicon Valley在住のエンジニアです。日本企業から突然アメリカ企業に転職して気が付いた事や知って役に立った事を書いています。