見出し画像

Stable Diffusionでfp8は使えるのか?

 Stable Diffusionで、pytorch 2.1.0から実装されたfp8が利用出来る様なったらしい。一番大きいのはメモリの減少。なおGeForceの場合、fp8は40xx移行のみ対応しているらしい。

fp8とはなんなのか

 fpとは浮動小数点(floating point)をさし、8は8bitを差す。元々CPUに於いては小数点の実装は回路が複雑になるため実装そのものを避けて整数演算(INT)のみをサポートしていることが多かった。そのため小数点演算専用のチップが別途必要だった時代もある。

 AIに於いては、一般的にfp32(単精度浮動小数点 double),fp16(半精度浮動小数点 float),bf16などが使われて居る。精度が重視される用途(科学技術計算など)ではfp64(倍精度小数点)が使われるが、AIでは精度はあまり重要視されないので精度が低くてもあまり問題ない(正直、筆書もこの辺はよく理解していない。出来るのはAIの泥沼に足ツッコミたい人向けの上級サイトに誘導するぐらい)

 fp8は、f16の半分なので四分一精度浮動小数点になる。AI用のチップは精度を落としてもまだいけると削りまくっている。ここで出てくるのはfp8だが、さらに精度を落としたINT4/2/1(INT4は4bitなので、0,1,2…15もしくは-7,-6,..-2,-1,0,1,2…,7,8までしか扱えない)まで存在する(何のAIに使っているかは知らない)精度を落とすことのメリットはまずモデルのサイズが小さく出来ること、もう一つは一度のより多くの計算が出来ること(つまり、より速くなる)になる。

浮動小数点

 小数点が動くから浮動小数点と言う。説明が面倒なので次の数字を浮動小数点表記にしてみることにする。

$$
1.123 = 1.123 \times 10^0 \\
10.6 = 1.06  \times 10^1 \\
0.023 = 2.3 \times 10^-2 \\
$$

 要するに指数表示になる(電卓やExcelで見たくないおなじみのアレ)。この例は十進法だが、コンピュータは二進法なので二進法で実装される。

 そして、1.1123 の部分を仮数部、10^0の部分を指数部と呼ぶ。これ以外に符号(±を現す)1bitが必要なので、実際に使えるのは数字-1になる(fp32なら31bit)仮数部の計算は面倒なので省略。基本的には1.0-2.0(2.0は含まない)の間の数字を取り、bit数が多いほど正確さ(精度)が増す。指数部は2の乗数を示す(11bitの場合 $${ 2^{11} = 2048 }$$ 確保できるため、$${ 2^{-1022} }$$ から $${ 2^{1023} }$$までが表記できる。これは概ね10の100乗まで表記可能)指数部は、表記出来る桁数を意味する。

ぶっちゃけ、
仮想部 → 精度
指数部 → 桁数

fp32 仮数部 23bit 指数部 8bit
fp16 仮数部 10bit 指数部 5bit
bf16 仮数部 7bit  指数部 8bit
fp8 E4M3  仮数部 3bit  指数部 4bit
fp8 E5M2  仮数部 2bit  指数部 5bit

 fp8は、日常生活の計算に使えないほど精度が低い。fp8 E5M2は仮数部 2bitなので1.0, 1.25, 1.5. 1.75しか表せない。これに2の乗数をかけた数までになる。しかしAIではあまり問題無い。Stable Diffusionで使っているアルゴリズムは、精度より桁数の方が重要なので指数部だけ一致していれば良い。要するにfp16の変わりにfp8 E5M2が適用出来る可能がある(しかし、出力する画像そのものは階調が256段階あるので、どこかでfp16が必要な気がする)

 fp16は、数字一つに2byte、fp8は1byteを消費するためfp8は、fp16に比べて必要なメモリが半分になる。

 なおbf16は、AIの計算が基本的に桁数しか影響しないところに着目し、指数部のみをfp32に合わせた物になる。Stable Diffusionでは使っていない。

Stable Diffusionにfp8は適用出来るのか?

 Stable Diffusionが使っているAIのアルゴリズムは大きく3つ、Transfomer(CLIP)(文書の抽象化)、diffusion(画像生成)、VAE(細部補間)になる。うちVAEはfp32にしないと桁あふれする可能性があり今もfp32が使われることが多い(VAEをfp16に設定すると真っ黒な絵が出るため基本fp32が使われる。VAEにfp16を適用するにはfp16用に作られたVAEモデルがいる)残り二つはnVideiaによればTransformerもdiffusionもfp8やINT8でも動くらしい(……と言うかTensorRT Extension for Stable Diffusion Web UIじゃないかよ)

 しかし、Stable Diffusionの実装系は、実際の計算は現時点(2024/05/27)では、fp16で行っている模様。

fp8のメリット

 一つ目は使用するメモリがfp16の半分になる点で、VRAMがシビアな環境で顕著に現れると思われる。SD1.5のcheck pointはfp16で2GBあるため、fp8にすると1GB(実際には0.8-0.9GBらしい)減らせる。つまり1GB他の用途に使えることになる、SD XLの場合、概ね6GBなので3GB(実際には2.5-2.8GBらしい)減らせる。つまり(モデルサイズ - VAEサイズ)*0.8 / 2 ぐらい減りそう。

 もう一つのメリットは、計算そのものに適用した場合、fp8はfp16にくらべて演算ユニットが倍になる為、速度が最大倍になる可能性がある(ベンチマークスコアもfp8はfp16の大体倍)

 しかし、実際の結果はどうなるかと言うと、計算は自体はfp16で行っており、むしろfp8から fp16の変換操作が加算されるためfp8を使うとfp16より遅くなる。

 つまり現状(2024/5/27)、Stable Diffusionのfp8はVRAMに展開するモデルのサイズを節約できるだけになる。これはVRAMがシビアな環境(8GB環境でSD XLを使うケース)が一番有用そう。作業に使えるVRAMが実質倍になるので。なお、このインプリメントはロードするモデルの精度を落としているだけなので、fp8をサポートしていない古いGPUでも適用出来そう。

 なおGPT-4oは既にfp8で動いていそう(GPT-4に比べて速度が倍になっているし)

fp8のデメリット

 誤差も累積すればばかにならない点になる。実際の検証では細かい部分で誤差が出ている。モデルの段階でここまで誤差が出ると、実際の演算に適用すると酷いことになるかも。また、fp32からfp16に移行したときVAEでおきた桁あふれの問題があるのでfp8に最適化した再学習が必要になるかも。

※ TensorRT Extension for Stable Diffusion Web UIをnative実装すればよいのにね。しかしこの拡張、制限多いし不安定だからな……。


この記事が気に入ったらサポートをしてみませんか?