見出し画像

PyTorch3D 入門

1. PyTorch3D

PyTorch3D」は、3Dグラフィックス向けの機械学習ライブラリです。「TensorFlow Graphics」「NVIDIA Kaolin」がTensorFlowをサポートするのに対し、「PyTorch3D」はPyTorchをサポートします。

2. 3Dグラフィックス向けの機械学習

3Dグラフィックス向けの機械学習の多くは、「2D画像」から「3D世界」の推論を行います。

画像11

学習は、「変換関数」の他に「レンダリング」も含めて誤差逆伝播を行い、最適値を見つけ出すことができます。

画像12

3D再構成」は、2D画像から3Dモデルを生成するタスクです。「入力画像」を「変換関数」で3Dモデルに変換し、それをレンダリングで2D画像に変換します。これが元の物体のシルエットに近くなるように、「変換関数」を更新します。

3. チュートリアルの内容

今回は、「球体メッシュ」を関数で変換した「予測メッシュ」が「ターゲットメッシュ」に近づくように学習します。

画像13

具体的には、以下の値を最小化します。

・chamfer_distance : 2つのポイントクラウドの距離。

また、これを最小化するだけでは、形状が滑らかにならないため、以下の「shape regularizers」を追加して滑らかにしています。

・mesh_edge_loss : バッチ内のメッシュの平均メッシュエッジ長の正規化損失。
・mesh_normal_consistency : メッシュ内の各メッシュの標準軟度。
・mesh_laplacian_smoothing : メッシュのバッチのラプラシアン平滑化目標。

4. インストール

「Google Colab」での「PyTorch3D」のインストール手順は、次のとおりです。

(1) 「Google Colab」で新規ノートを作成し、メニュー「編集→ノートブックの設定」で「GPU」を指定。

(2) 以下のコマンドを実行

!pip install torch torchvision
!pip install 'git+https://github.com/facebookresearch/pytorch3d.git@stable'

(3) バージョンを確認

import pytorch3d as p3d
print(p3d.__version__)
0.2.0

3. パッケージのインポート

パッケージをインポートします。

import os
import torch
from pytorch3d.io import load_obj, save_obj
from pytorch3d.structures import Meshes
from pytorch3d.utils import ico_sphere
from pytorch3d.ops import sample_points_from_meshes
from pytorch3d.loss import (
    chamfer_distance,
    mesh_edge_loss,
    mesh_laplacian_smoothing,
    mesh_normal_consistency,
)
import numpy as np
from tqdm import tqdm_notebook

# matplotlib
%matplotlib notebook
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
import matplotlib as mpl
mpl.rcParams['savefig.dpi'] = 80
mpl.rcParams['figure.dpi'] = 80

# デバイス
device = torch.device("cuda:0")

4. ターゲットメッシュの生成 

メッシュファイル(*.obj)を読み込んで、ターゲットメッシュを生成します。

(1) イルカのメッシュファイル(delphin.obj)の取得。

!wget https://dl.fbaipublicfiles.com/pytorch3d/data/dolphin/dolphin.obj

(2) メッシュファイルの読み込み。

# メッシュファイルの読み込み
trg_obj = os.path.join('dolphin.obj')

(3) ターゲットメッシュの生成。

# 頂点と面とauxの取得
verts, faces, aux = load_obj(trg_obj)
faces_idx = faces.verts_idx.to(device)
verts = verts.to(device)

# (0,0,0)を中心とする半径1の球にフィットするように正規化・中心化
center = verts.mean(0)
verts = verts - center
scale = max(verts.abs().max(0)[0])
verts = verts / scale

# ターゲットメッシュの生成
trg_mesh = Meshes(verts=[verts], faces=[faces_idx])

5. ソースメッシュの生成

「ソースメッシュ」を生成します。

# ソースメッシュの生成
src_mesh = ico_sphere(4, device)

6. ターゲットメッシュとソースメッシュのプロット

「ターゲットメッシュ」と「ソースメッシュ」のプロットを行います。

# メッシュのプロット
def plot_pointcloud(mesh, title=""):
    points = sample_points_from_meshes(mesh, 5000)
    x, y, z = points.clone().detach().cpu().squeeze().unbind(1)
    fig = plt.figure(figsize=(5, 5))
    ax = Axes3D(fig)
    ax.scatter3D(x, z, -y)
    ax.set_xlabel('x')
    ax.set_ylabel('z')
    ax.set_zlabel('y')
    ax.set_title(title)
    ax.view_init(190, 30)
    plt.show()
# ターゲットメッシュとソースメッシュのプロット
plot_pointcloud(trg_mesh, "Target mesh")
plot_pointcloud(src_mesh, "Source mesh")

画像1

画像2

7. 最適化ループ

最適化ループを実行します。

# 変換関数の形状は、src_meshの頂点数と同じ
deform_verts = torch.full(src_mesh.verts_packed().shape, 0.0, device=device, requires_grad=True)

# オプティマイザ
optimizer = torch.optim.SGD([deform_verts], lr=1.0, momentum=0.9)

Niter = 2000 # 最適化ステップの数
w_chamfer = 1.0 # chamfer loss の重み
w_edge = 1.0 # edge lossの重み
w_normal = 0.01 # mesh normal consistencyの重み
w_laplacian = 0.1 # mesh laplacian smoothingの重み
plot_period = 250 # プロット頻度
loop = tqdm_notebook(range(Niter))

chamfer_losses = []
laplacian_losses = []
edge_losses = []
normal_losses = []

%matplotlib inline

for i in loop:
    # オプティマイザの初期化
    optimizer.zero_grad()

    # メッシュの変形
    new_src_mesh = src_mesh.offset_verts(deform_verts)

    # 各メッシュの表面から5000個の点をサンプリング
    sample_trg = sample_points_from_meshes(trg_mesh, 5000)
    sample_src = sample_points_from_meshes(new_src_mesh, 5000)

    # chamfer loss
    loss_chamfer, _ = chamfer_distance(sample_trg, sample_src)

    # edge loss
    loss_edge = mesh_edge_loss(new_src_mesh)

    # normal loss
    loss_normal = mesh_normal_consistency(new_src_mesh)

    # laplacian loss
    loss_laplacian = mesh_laplacian_smoothing(new_src_mesh, method="uniform")

    # 損失の加重合計
    loss = loss_chamfer * w_chamfer + loss_edge * w_edge + loss_normal * w_normal + loss_laplacian * w_laplacian

    # 損失の出力
    loop.set_description('total_loss = %.6f' % loss)

    # プロットのための損失の保存
    chamfer_losses.append(loss_chamfer)
    edge_losses.append(loss_edge)
    normal_losses.append(loss_normal)
    laplacian_losses.append(loss_laplacian)

    # メッシュのプロット
    if i % plot_period == 0:
        plot_pointcloud(new_src_mesh, title="iter: %d" % i)

    # 最適化ステップ
    loss.backward()
    optimizer.step()

画像3

画像4

画像5

画像6

画像7

画像8

画像9

8. 損失をグラフにプロット

「損失」をグラフにプロットします。

fig = plt.figure(figsize=(13, 5))
ax = fig.gca()
ax.plot(chamfer_losses, label="chamfer loss")
ax.plot(edge_losses, label="edge loss")
ax.plot(normal_losses, label="normal loss")
ax.plot(laplacian_losses, label="laplacian loss")
ax.legend(fontsize="16")
ax.set_xlabel("Iteration", fontsize="16")
ax.set_ylabel("Loss", fontsize="16")
ax.set_title("Loss vs iterations", fontsize="16")

画像10

9. 予測メッシュの保存

最後に、「予測メッシュ」をファイル(final_model.obj)に保存します。

# 予測メッシュの頂点と面の取得
final_verts, final_faces = new_src_mesh.get_mesh_verts_faces(0)

# スケールを正規化して元のターゲットサイズに戻す
final_verts = final_verts * scale + center

# 予測メッシュの保存
final_obj = os.path.join('./', 'final_model.obj')
save_obj(final_obj, final_verts, final_faces)


この記事が気に入ったらサポートをしてみませんか?