PyTorch3D 入門
1. PyTorch3D
「PyTorch3D」は、3Dグラフィックス向けの機械学習ライブラリです。「TensorFlow Graphics」「NVIDIA Kaolin」がTensorFlowをサポートするのに対し、「PyTorch3D」はPyTorchをサポートします。
2. 3Dグラフィックス向けの機械学習
3Dグラフィックス向けの機械学習の多くは、「2D画像」から「3D世界」の推論を行います。
学習は、「変換関数」の他に「レンダリング」も含めて誤差逆伝播を行い、最適値を見つけ出すことができます。
「3D再構成」は、2D画像から3Dモデルを生成するタスクです。「入力画像」を「変換関数」で3Dモデルに変換し、それをレンダリングで2D画像に変換します。これが元の物体のシルエットに近くなるように、「変換関数」を更新します。
3. チュートリアルの内容
今回は、「球体メッシュ」を関数で変換した「予測メッシュ」が「ターゲットメッシュ」に近づくように学習します。
具体的には、以下の値を最小化します。
・chamfer_distance : 2つのポイントクラウドの距離。
また、これを最小化するだけでは、形状が滑らかにならないため、以下の「shape regularizers」を追加して滑らかにしています。
・mesh_edge_loss : バッチ内のメッシュの平均メッシュエッジ長の正規化損失。
・mesh_normal_consistency : メッシュ内の各メッシュの標準軟度。
・mesh_laplacian_smoothing : メッシュのバッチのラプラシアン平滑化目標。
4. インストール
「Google Colab」での「PyTorch3D」のインストール手順は、次のとおりです。
(1) 「Google Colab」で新規ノートを作成し、メニュー「編集→ノートブックの設定」で「GPU」を指定。
(2) 以下のコマンドを実行
!pip install torch torchvision
!pip install 'git+https://github.com/facebookresearch/pytorch3d.git@stable'
(3) バージョンを確認
import pytorch3d as p3d
print(p3d.__version__)
0.2.0
3. パッケージのインポート
パッケージをインポートします。
import os
import torch
from pytorch3d.io import load_obj, save_obj
from pytorch3d.structures import Meshes
from pytorch3d.utils import ico_sphere
from pytorch3d.ops import sample_points_from_meshes
from pytorch3d.loss import (
chamfer_distance,
mesh_edge_loss,
mesh_laplacian_smoothing,
mesh_normal_consistency,
)
import numpy as np
from tqdm import tqdm_notebook
# matplotlib
%matplotlib notebook
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
import matplotlib as mpl
mpl.rcParams['savefig.dpi'] = 80
mpl.rcParams['figure.dpi'] = 80
# デバイス
device = torch.device("cuda:0")
4. ターゲットメッシュの生成
メッシュファイル(*.obj)を読み込んで、ターゲットメッシュを生成します。
(1) イルカのメッシュファイル(delphin.obj)の取得。
!wget https://dl.fbaipublicfiles.com/pytorch3d/data/dolphin/dolphin.obj
(2) メッシュファイルの読み込み。
# メッシュファイルの読み込み
trg_obj = os.path.join('dolphin.obj')
(3) ターゲットメッシュの生成。
# 頂点と面とauxの取得
verts, faces, aux = load_obj(trg_obj)
faces_idx = faces.verts_idx.to(device)
verts = verts.to(device)
# (0,0,0)を中心とする半径1の球にフィットするように正規化・中心化
center = verts.mean(0)
verts = verts - center
scale = max(verts.abs().max(0)[0])
verts = verts / scale
# ターゲットメッシュの生成
trg_mesh = Meshes(verts=[verts], faces=[faces_idx])
5. ソースメッシュの生成
「ソースメッシュ」を生成します。
# ソースメッシュの生成
src_mesh = ico_sphere(4, device)
6. ターゲットメッシュとソースメッシュのプロット
「ターゲットメッシュ」と「ソースメッシュ」のプロットを行います。
# メッシュのプロット
def plot_pointcloud(mesh, title=""):
points = sample_points_from_meshes(mesh, 5000)
x, y, z = points.clone().detach().cpu().squeeze().unbind(1)
fig = plt.figure(figsize=(5, 5))
ax = Axes3D(fig)
ax.scatter3D(x, z, -y)
ax.set_xlabel('x')
ax.set_ylabel('z')
ax.set_zlabel('y')
ax.set_title(title)
ax.view_init(190, 30)
plt.show()
# ターゲットメッシュとソースメッシュのプロット
plot_pointcloud(trg_mesh, "Target mesh")
plot_pointcloud(src_mesh, "Source mesh")
7. 最適化ループ
最適化ループを実行します。
# 変換関数の形状は、src_meshの頂点数と同じ
deform_verts = torch.full(src_mesh.verts_packed().shape, 0.0, device=device, requires_grad=True)
# オプティマイザ
optimizer = torch.optim.SGD([deform_verts], lr=1.0, momentum=0.9)
Niter = 2000 # 最適化ステップの数
w_chamfer = 1.0 # chamfer loss の重み
w_edge = 1.0 # edge lossの重み
w_normal = 0.01 # mesh normal consistencyの重み
w_laplacian = 0.1 # mesh laplacian smoothingの重み
plot_period = 250 # プロット頻度
loop = tqdm_notebook(range(Niter))
chamfer_losses = []
laplacian_losses = []
edge_losses = []
normal_losses = []
%matplotlib inline
for i in loop:
# オプティマイザの初期化
optimizer.zero_grad()
# メッシュの変形
new_src_mesh = src_mesh.offset_verts(deform_verts)
# 各メッシュの表面から5000個の点をサンプリング
sample_trg = sample_points_from_meshes(trg_mesh, 5000)
sample_src = sample_points_from_meshes(new_src_mesh, 5000)
# chamfer loss
loss_chamfer, _ = chamfer_distance(sample_trg, sample_src)
# edge loss
loss_edge = mesh_edge_loss(new_src_mesh)
# normal loss
loss_normal = mesh_normal_consistency(new_src_mesh)
# laplacian loss
loss_laplacian = mesh_laplacian_smoothing(new_src_mesh, method="uniform")
# 損失の加重合計
loss = loss_chamfer * w_chamfer + loss_edge * w_edge + loss_normal * w_normal + loss_laplacian * w_laplacian
# 損失の出力
loop.set_description('total_loss = %.6f' % loss)
# プロットのための損失の保存
chamfer_losses.append(loss_chamfer)
edge_losses.append(loss_edge)
normal_losses.append(loss_normal)
laplacian_losses.append(loss_laplacian)
# メッシュのプロット
if i % plot_period == 0:
plot_pointcloud(new_src_mesh, title="iter: %d" % i)
# 最適化ステップ
loss.backward()
optimizer.step()
8. 損失をグラフにプロット
「損失」をグラフにプロットします。
fig = plt.figure(figsize=(13, 5))
ax = fig.gca()
ax.plot(chamfer_losses, label="chamfer loss")
ax.plot(edge_losses, label="edge loss")
ax.plot(normal_losses, label="normal loss")
ax.plot(laplacian_losses, label="laplacian loss")
ax.legend(fontsize="16")
ax.set_xlabel("Iteration", fontsize="16")
ax.set_ylabel("Loss", fontsize="16")
ax.set_title("Loss vs iterations", fontsize="16")
9. 予測メッシュの保存
最後に、「予測メッシュ」をファイル(final_model.obj)に保存します。
# 予測メッシュの頂点と面の取得
final_verts, final_faces = new_src_mesh.get_mesh_verts_faces(0)
# スケールを正規化して元のターゲットサイズに戻す
final_verts = final_verts * scale + center
# 予測メッシュの保存
final_obj = os.path.join('./', 'final_model.obj')
save_obj(final_obj, final_verts, final_faces)
この記事が気に入ったらサポートをしてみませんか?