NeRFとGaussian Splatting~NeRFの実装編

先日、NeRFと3D Gaussian Splattingについてアルゴリズム概要を書きました

今回はNeRFについてもう少し実装面を取り上げてみます


論文

NeRFの親戚はたくさんありますが、提案論文のこちらを見ていきます

コード

論文のサポートGitHubレポジトリにtiny_nerfという、Google Colabで動く軽量なNeRFが公開されています

論文の理解には都合が良いので、こちらを読んでいきましょう

モデルの初期化 - init_model

8層の全結合層+Reluで構成されています
NeRFの提案論文はもともとそんなに凝ったモデルは使っていないので、このような感じです

  • 入力はカメラに入射した光線が通った各点の座標です

  • 出力は各点が持つ色(RGB三次元)と吸収率で計四次元のベクトル(をpositoinal encodingしたもの)です

def init_model(D=8, W=256):
    '''
    8層のパーセプトロン、出力は4値
    '''
    relu = tf.keras.layers.ReLU()
    dense = lambda W=W, act=relu : tf.keras.layers.Dense(W, activation=act)

    inputs = tf.keras.Input(shape=(3 + 3*2*L_embed))
    outputs = inputs
    for i in range(D):
        outputs = dense()(outputs)
        if i%4==0 and i>0:
            outputs = tf.concat([outputs, inputs], -1)
    outputs = dense(4, act=None)(outputs) #\vec{c}, \sigma

    model = tf.keras.Model(inputs=inputs, outputs=outputs)
    return model


Mildenhall et al. (2024)

光線の取得 - get_rays

論文中にo+tdという形式で書かれている光線ベクトルの定義ですね
初見だと何を言っているかわからないところもありますが、透視投影モデルであらわしたカメラ姿勢の式がベースになっています

こちらの教科書が参考になると思います

def get_rays(H, W, focal, c2w):
    '''
    H, W: 画像サイズ
    c2w: pose, camera2
    rays_o : coordinates in image plane
    rays_d : viewing direction
    '''
    i, j = tf.meshgrid(tf.range(W, dtype=tf.float32), tf.range(H, dtype=tf.float32), indexing='xy')
    dirs = tf.stack([(i-W*.5)/focal, -(j-H*.5)/focal, -tf.ones_like(i)], -1) #カメラ全体の回転をc2wが持っている 、xk_x + yk_yと基本発想は同じ
    rays_d = tf.reduce_sum(dirs[..., np.newaxis, :] * c2w[:3,:3], -1)
    rays_o = tf.broadcast_to(c2w[:3,-1], tf.shape(rays_d))
    return rays_o, rays_d

ボリュームレンダリング - render_rays

ざっくりいうと
1. 光線を計算
2. 学習するネットワークで推論
3. ボリュームレンダリング
を繰り返しています

def render_rays(network_fn, rays_o, rays_d, near, far, N_samples, rand=False):

    def batchify(fn, chunk=1024*32):
        return lambda inputs : tf.concat([fn(inputs[i:i+chunk]) for i in range(0, inputs.shape[0], chunk)], 0)

    # Compute 3D query points
    z_vals = tf.linspace(near, far, N_samples)
    if rand:
      z_vals += tf.random.uniform(list(rays_o.shape[:-1]) + [N_samples]) * (far-near)/N_samples
    pts = rays_o[...,None,:] + rays_d[...,None,:] * z_vals[...,:,None]

    # Run network
    pts_flat = tf.reshape(pts, [-1,3])
    pts_flat = embed_fn(pts_flat)
    raw = batchify(network_fn)(pts_flat)
    raw = tf.reshape(raw, list(pts.shape[:-1]) + [4])

    # Compute opacities and colors
    sigma_a = tf.nn.relu(raw[...,3])
    rgb = tf.math.sigmoid(raw[...,:3])

    # Do volume rendering
    dists = tf.concat([z_vals[..., 1:] - z_vals[..., :-1], tf.broadcast_to([1e10], z_vals[...,:1].shape)], -1)
    alpha = 1.-tf.exp(-sigma_a * dists)
    weights = alpha * tf.math.cumprod(1.-alpha + 1e-10, -1, exclusive=True)

    rgb_map = tf.reduce_sum(weights[...,None] * rgb, -2)
    depth_map = tf.reduce_sum(weights * z_vals, -1)
    acc_map = tf.reduce_sum(weights, -1)

    return rgb_map, depth_map, acc_map

モデルのtrain

これらをもとにモデルを学習します
次の良な感じですね
1. カメラが取得した画像とそれに対応づくカメラの姿勢を取ってくる
2. カメラの姿勢に紐づくカメラに入射した光線をシミュレート
3. その光線をもとにボリュームレンダリングする
4. 2~3に対して勾配を計算し、ネットワークを最適化

model = init_model()
optimizer = tf.keras.optimizers.Adam(5e-4)

N_samples = 64
N_iters = 1000
psnrs = []
iternums = []
i_plot = 25

import time
t = time.time()
for i in range(N_iters+1):

    img_i = np.random.randint(images.shape[0])
    target = images[img_i]
    pose = poses[img_i]
    rays_o, rays_d = get_rays(H, W, focal, pose) # \vec{o]+ t\vec{d}
    with tf.GradientTape() as tape:
        rgb, depth, acc = render_rays(model, rays_o, rays_d, near=2., far=6., N_samples=N_samples, rand=True)
        loss = tf.reduce_mean(tf.square(rgb - target))
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

    if i%i_plot==0:
        print(i, (time.time() - t) / i_plot, 'secs per iter')
        t = time.time()

        # Render the holdout view for logging
        rays_o, rays_d = get_rays(H, W, focal, testpose)
        rgb, depth, acc = render_rays(model, rays_o, rays_d, near=2., far=6., N_samples=N_samples)
        loss = tf.reduce_mean(tf.square(rgb - testimg))
        psnr = -10. * tf.math.log(loss) / tf.math.log(10.)

        psnrs.append(psnr.numpy())
        iternums.append(i)

        plt.figure(figsize=(10,4))
        plt.subplot(121)
        plt.imshow(rgb)
        plt.title(f'Iteration: {i}')
        plt.subplot(122)
        plt.plot(iternums, psnrs)
        plt.title('PSNR')
        plt.show()

print('Done')

最終的にはこんな出力が出てくるはずです:

https://colab.research.google.com/github/bmild/nerf/blob/master/tiny_nerf.ipynbより抜粋

(補足) positional encodingについて

前回説明しなかったので、positional encodingについて説明します

画像を出力するようなニューラルネットワークでは特に、ニューラルネットワークを最適化しても出力像の鮮明さが失われるということが良く起こります
出力像 yの座標(i, j)の画素にはx(i, j)だけでなくx(i-5, j-5)など近傍の画素からも情報が入り込んでしまうためですね

これを回避する手段はいくつかあるのですが、その一つにpositional encodingを挙げられます

positional encodingでは出力の(i, j)の画素に周囲からの情報が混ざらないよう、画素ごとに周波数の異なる正弦波を掛け算しておきます
周波数の異なる正弦波は互いに打ち消しあいますので、先ほどの例に即していうと出力画像のy(i, j)の画素に対して、入力画像のx(i, j)の画素が強く影響するように調整できる…というトリックです

今回は光線ベクトルをネットワークに入れる直前にpositional encodingをしています
上記例とは若干状況が異なりますが、同様の効果が得られると期待できます
(光線の進行方向に隣接している画素間の情報の混ざりあいが抑制できるので、結果的に出力モデルの先鋭化に寄与するはずです)

Mildenhall et al. (2020)より抜粋




この記事が気に入ったらサポートをしてみませんか?