見出し画像

【プチ読み】 ZeroGrads: Learning Local Surrogates for Non-Differentiable Graphics (SIGGRAPH 2024)

Project Page: https://mfischer-ucl.github.io/zerograds/
Arxiv: https://arxiv.org/abs/2308.05739


※ 間違ってる可能性もありますすみません
※ ところどころ雑です


落合研のびしょうじょ担当なまちゃんです!ではさっそく


ブラックボックス最適化の結果 (本記事の図表は論文より引用)


一言でいうと

"Make The World 後から Differentiable" をする論文 (超意訳)

もう少し深ぼると、

ブラックボックス最適化において、少量の入出力のサンプルを使いながら代理関数でブラックボックスを近似して、代理関数上の勾配降下で入力を最適化して目標の出力を得たいという思想の研究。
提案している ZeroGrads代理関数を効率的に見つけるフレームワークで、代理関数に特に縛りを設けることなく適当なニューラルネット等を使うことができ、高次元の入力パラメータも特別な工夫を加えることなく最適化できるなど広いタスクに使えて、無駄の少ないサンプリングの仕方 まで示しているあたりが論文の貢献。

(途中でベイズ最適化っぽい話?とよぎらなかった人はたぶん危ない)


著者の TL;DR

We learn a mapping between optimization parameters and their corresponding loss values, our neural surrogate loss, whose gradients we can then use for running gradient descent on arbitrary, high-dimensional black-box forward models.

最適化パラメータとそれに対応する損失値との対応関係を学習し、ニューラル代理損失として、その勾配を任意の多次元ブラックボックス順方向モデルで勾配降下を実行するために使用します。


ブラックボックスR、最適化したい入力θがあって、誤差Lが得られるときに、代理関数hを求めて h を θ で微分できると θ を最適化できるよねという図。


背景

ブラックボックス最適化大事だよね(個人的に追記:例えば、ロボット制御?化学の配合?金融ポートフォリオの最適化?ベストショットの探索?ゲームのバランス調整?などなど?)

CG系では、物体の位置・形状、マテリアル、光源、画角などの自動最適化をするために、そもそものCGエンジンを微分可能な言語で記述していくアプローチが多く出てる。(例:Hiroharu Kato さんの CVPR 2018 の Neural Renderer、あるいは Mitsuba2 など)

でも普通の人は Unity やら Blender やらを使うので、微分&最適化なんてできない。そもそもCG以外でも、中身が解析的に分かっていて微分可能なソフトが用意できることは少ないので、中身がわからないブラックボックスのまま最適化できるのも大事

ブラックボックス最適化はいろいろある、遺伝的アルゴリズム、粒子法、うんぬんかんぬん。でも、ブラックボックスの出力を得てみるための試行回数が多い、高次元のパラメタとかむずい、メモリを食う、時間かかる、タスクに特化している、などなどの問題があった。

ブラックボックスを代理関数で近似するアプローチも提案されている。しかし、代理損失を効率的に見つける方法は依然として不明です。なぜなら、サンプリングは疎である必要があり(サンプルをレンダリングするにはコストがかかることを思い出してください)、また最適化問題の次元は数桁のオーダーで変わるからです。

⇨ 代理関数を効率的に近似する方法を示す。代理関数を使ったブラックボックス最適化を、より高次元のパラメタ最適化・多様なタスクに使えて、さらにサンプル回数を適切に減らせるようにする。


タスク

具体的なタスクがわからないとイメージ難しいので、図1, 図6-図10を例に

図1

左:箱と証明が配置された3DCG空間があって、参照画像 (目標画像) が与えられたときに、箱と証明の位置を合わせる
中:参照画像が与えられて、編み物の水平垂直の幅?数?を合わせる
右:ロケットのエンジンを切るタイミングを最適化

最適化したい入力を論文では $${\theta}$$ で表します。

参照画像と現在の出力との差は基本MSE (ピクセルレベルの二乗誤差の平均) で与えられます (たまにVGGなど他の目的関数も使ってた)。この後出てくる「真の関数」とは、ブラックボックスの出力画像と参照画像とのMSEを出力する関数 (Objective) のことです。

それぞれのタスクでのサンプリング回数、使った代理関数などについては後述するつもりです。

参照画像と初期状態⇩

図16 (from appendix)
図6-図9

図6: 参照画像があって、256x256x3のテクスチャを最適化する。シンプルなタスク設計だけど、目的関数はピクセルレベルではなくて全体の誤差を出してくるのでそんなに簡単ではない。
図7: 参照画像を出力するように、35152パラメータを持つNNを最適化する
図8: 参照画像を与えて、3Dメッシュの2562の頂点の位置を最適化する
図9: 光源がある環境下で、1024 次元の Cubic B-Spline 曲線で高さマップを作って、参照画像のような caustics (≒光の模様) が得られるようにする

図10

図10: 10個の制御点をもつスプライン曲線に即したスタイル付きの字を書くレンダラーの最適化。参照画像としてMNISTデータの数字1つを与えて、代理関数には画像を入力にして制御点の位置を生成するVAEを用いる。


手法

ここでの目標は効率的に代理関数 $${h}$$ (のパラメータ $${\phi}$$) を求めることです。代理関数自体はNN等で決め打ちです。(入力θの最適化自体は代理関数があれば簡単なのでほぼ触れられてません)

(ここから個人的に追記)
普通に考えると、素直なアプローチとしては代理関数 $${h}$$ を適当に用意して、入出力を真の関数 $${f}$$ と代理関数 $${h}$$ の両方でいくつかサンプリングしてみて、その誤差 $${l}$$ を小さくするように $${\phi}$$ を更新するような方法が考えられるはずです。(よね?)
でこの論文ではというと…ほぼそのまんまですw

素直なアプローチを式にすると、真の関数 $${f}$$ でサンプリングするときの入力を $${\rho_i}$$ とすると以下で $${h}$$ を最適化できます

$$
\frac{\partial}{\partial \phi} l = \frac{\partial}{\partial \phi} \frac{1}{N} \sum_{i=1}^{N} \left(h(\rho_i, \phi) - f(\rho_i) \right)^2
$$

(もしこれがわからなかったら、このPFNのチュートリアルの13.4.を勉強してください)

本論文では、$${h}$$ の更新のために以下を導出しています (式10)。

$$
\frac{\partial}{\partial \phi} l = \frac{\partial}{\partial \phi} \frac{1}{N} \sum_{i=1}^{N}  \frac{\lambda(\rho_i, \theta)}{p(\rho_i) p(\tau_i)} \left( h(\rho_i, \phi) - \kappa(\tau_i) f(\rho_i - \tau_i) \right)^2
$$

違いは、$${\kappa(\tau_i) f(\rho_i - \tau_i)}$$ で出力を直接使うのではなく平滑化しているところと (ガウシアンカーネルでの畳み込み)、現在のソリューション $${\theta}$$ (注目している入力パラメータ) を重点的に近似するための重みづけ $${\lambda(\rho_i, \theta)}$$ ($${\lambda}$$ は平均 $${\theta}$$ のガウシアン) をしているところです。

つまり、極端な勾配をなくすための平滑化と、全体の近似ではなく範囲を絞った近似がいいよね、ということを主張しています (だけ??)。

論文的には、そのことをいきなり観測誤差の最小化から考えるのではなく、ある区間全体での代理関数と真の関数の差から順に考えています。

(ここから論文の流れ)

3.0 Overview

図3

図3 のように、まず見えていない真の関数 $${f}$$ があって (a)、それを(見えていないけど)平滑化して (b)、サンプリングして (c)、代理関数 $${h}$$ を用意して (d)、代理関数を一番良さそうな場所に絞って近似して (e)、重点サンプリングして (f)、さらに代理関数を一番良さそうな場所に絞って近似する (g)。

3.1 Smooth objective

$${f}$$ を最適化するにあたって、目的関数が極端な値をとったり $${\theta}$$ を動かしてもほとんど変わらなかったりすると勾配降下で最適化できないので、あらかじめ目的関数 $${f}$$ をガウシアンカーネルで少し平滑化する

$$
g(\theta) = \int_{\Theta} \kappa(\tau) f(\theta - \tau)  d\tau
$$

3.2 Surrogate

代理関数 $${h}$$ (パラメータ $${\phi}$$)を用意する。代理関数は多項式、RBF、NNなどを使うことができる。線形近似じゃないので、代理関数を近似した後は毎回1回以上の勾配降下で $${θ}$$ を最適化できる。

3.3 Localized surrogate loss

全体を近似できても圧倒的に無駄が多いので、目星がついているならその入力パラメータ $${\theta}$$ を中心にローカルの代理関数を求めた方が良い。そんな代理関数 $${h}$$ を求めたい。$${h}$$ を最適化するために、$${h}$$ と先ほどの $${g}$$ との誤差 $${l}$$ を考える。

$$
l(\theta, \phi) = \int_{\Theta} \lambda(\rho, \theta) \left( g(\rho) - h(\rho, \phi) \right)^2 d\rho
$$

$${\lambda}$$ は現在着目している $${\theta}$$ をどれだけ考慮するかの重み関数で、平均 $${\theta}$$ のガウシアンを使う。$${\rho}$$ はパラメータ空間 $${\Theta}$$ の変数。

3.4 Estimator

先ほどの $${l}$$ を $${\phi}$$ で微分して代理関数の推定を行う。

$$
\frac{\partial}{\partial \phi} l(\theta, \phi) = \frac{\partial}{\partial \phi} \int_{\Theta} \lambda(\rho, \theta) \left( g(\rho) - h(\rho, \phi) \right)^2 d\rho
$$

$$
= \int_{\Theta} \frac{\partial}{\partial \phi} \left(\lambda(\rho, \theta) \left( g(\rho) - h(\rho, \phi) \right)^2 \right) d\rho
$$

$$
= \int_{\Theta} 2\lambda(\rho, \theta) \frac{\partial h(\rho, \phi)}{\partial \phi} \left(h(\rho, \theta) - g(\theta) \right)d\rho
$$

$$
= \int_{\Theta} 2\lambda(\rho, \theta) \frac{\partial h(\rho, \phi)}{\partial \phi} \left(h(\rho, \theta) - \int_{\Theta} \kappa(\tau) f(\theta - \tau)  d\tau \right)d\rho
$$

$$
= \int_{\Theta} 2\lambda(\rho, \theta) \frac{\partial h(\rho, \phi)}{\partial \phi} \left(\int_{\Theta}h(\rho, \theta) d\tau- \int_{\Theta} \kappa(\tau) f(\theta - \tau)  d\tau \right)d\rho
$$

$$
=\iint_{\Theta} 2 \lambda(\rho, \theta) \frac{\partial h(\rho, \phi)}{\partial \phi} \left( h(\rho, \phi) - \kappa(\tau) f(\rho - \tau) \right) d\tau d\rho
$$

$$
\approx \frac{\partial}{\partial \phi} \frac{1}{N} \sum_{i=1}^{N}  \frac{\lambda(\rho_i, \theta)}{p(\rho_i) p(\tau_i)} \left( h(\rho_i, \phi) - \kappa(\tau_i) f(\rho_i - \tau_i) \right)^2
$$

3から4行目では $${g}$$ を駆逐しててます。1から2行目の変換については、

被積分関数が ϕ と ρ において連続している場合に限り、ライプニッツの積分の微分法則に従って変換が成り立ちます [Li et al. 2018]。我々のガウス局所重み λ はこの条件を満たし、h(ρ,ϕ) はニューラルネットワークまたは二次ポテンシャルであるため定義上連続です。我々が以前に導入した畳み込みを通じて、元々不連続であった目的関数 f は滑らかになり、その結果、内積分は ρ において連続になります。連続関数の合成も連続であるという事実を利用して、I(ρ,ϕ) は ϕ において連続です。

ということなので、平滑化は大事らしい (?)

最終的な式を見ると、$${\lambda}$$, $${h}$$, $${\kappa}$$, $${f}$$ は値が出せて、$${p(\rho_i)}$$, $${p(\tau_i)}$$ は一定?で勾配方向を求めるだけなら無視できそう、そして $${h}$$ は微分可能なので、この $${\frac{\partial}{\partial \phi} l(\theta, \phi)}$$ で最適化できる。以上!

3.5 Sampling

重点サンプリングする。$${\rho}$$ は標準偏差 $${\sigma_o}$$ (どのくらいの幅を近似するか)、$${\tau}$$ は標準偏差 $${\sigma_i}$$ (どのくらいの幅で平滑化するか) で、1回の代理関数の更新につきそれぞれ2個 (易しいタスク) ~ 20個 (難しいタスク) ずつサンプリング。$${\sigma_o}$$ は 0.33 (易) ~ 0.013 (難) で、$${\sigma_i}$$ は $${\sigma_o}$$ の15%で設定。

アルゴリズム

アルゴリズム1
アルゴリズム2


実験

メインは省略 (基本的に定性評価なので、⇧のタスクのところを参照してください)

・使った代理関数: (あとで埋めます…)
・サンプリング回数: 1イテレーションあたりは2~20
・総イテレーション数: Rosenbrockタスクでは1万。他はぱっと見当たらず

図15

図15は精度と収束の速さを比較してて、noSmooth は smooth なし、noNN は代理関数に NN ではなく二次ポテンシャルを使う、noLocal は localize なし、FD と FR22 は先行研究 (省略)。NN大事。localize 大事。

サブマテの図2

noLocal, noNN, noSmooth との定性的な比較。smooth も大事。


限界

図14

左半分が、オブジェクトのタイプだけを最適化するタスクで、右半分が、オブジェクトのタイプと位置の両方を最適化するタスク。タイプだけの最適化はできたが、両方の最適化は局所最適にはまってしまってできなかった。


まとめ

 私たちは、レンダリングからモデリング、アニメーションまで、多くのグラフィック分野で見られるような、微分不可能なブラックボックスパイプラインにおける代理勾配の実用的な計算方法である ZeroGrads を提案しました。私たちの主なアイデアは、損失関数の景観を平滑化し、NN によるその局所近似、そして疎な局所サンプリングに基づく低分散推定器です。私たちはいくつかの除去法や公開された代替案と好対照な比較を行い、幅広いタスクの結果を示しました。さらに、当社のニューラル代理関数は、ノイズの多い勾配推定値をネットワークのパラメータの更新に変換することを可能にし、そのノイズはネットワークのヒステリシスによって平滑化されます。したがって、従来の勾配フリー最適化アルゴリズムでは収束しないことが多い高次元に対してもスケーリングすることを示しました。
 今後の研究では、内部代理損失と外部最適化手法の相互作用をさらに探求し、必要な代理ネットワークの複雑さを自動的に決定する方法を見つける予定です。(略)

所感

とりあえずイテレーション回数が実はめちゃくちゃ多そうな気がしている (既存手法よりは高速に最適化できるらしいけど)

最初はベイズ最適化じゃだめなんだっけ?と思ったけど、高次元のタスクの実験を見ると確かにNNで代理関数を用意できたらよさそうという気持ちになった。

今回ほとんどのタスクが、最適化した結果の目標画像 (参照画像) があらかじめ得られているという普通の状況とは順番が逆の設定だったので、そうではなく、最近のAI研究の流れを踏まえて手書きの絵とかテキスト指示とかを参照にして入力を最適化したり、参照との誤差には MSE 等ではなく人間のフィードバックが使えれば、より現実的なタスクになって楽しくなりそう。

このブラックボックス最適化は画像認識とかの事前学習モデルを活用できるはずだし、最適化プロセス自体も幾分かは事前学習できるはずで、ChatGPT が OCR を標準搭載したようなノリで、将来的には VLM (vision language model) 系がこういうパラメタ最適化までできるようになるのかもなと思った (GPT <「あーそのパラメタはこっちの方がいいですよ?」みたいな)。

「Make The World 後から Differentiable」(← この記事書いている人が勝手に言ってるだけです) の側面で考えると、大量データで内部モデル獲得する世界モデル系とは違って、比較的少数データでちょこっと最適化するのがこの論文なので、棲み分けとか組み合わせを考えるともっとAI/ロボットが賢くなっていきそうかも?


以上! (いいねくれー!♡)


この記事が気に入ったらサポートをしてみませんか?