見出し画像

ビット演算の話

10年近く前、圧縮したまま高速な検索ができるデータ構造(簡潔データ構造)が一部で流行しました。

私も業務で作ってみたのですが、処理を高速化するにあたって普段はあまり使わないようなビット演算の知識を仕入れたことを覚えています。

ビット演算の世界は奥が深く、コアな書籍が出ているお陰で手軽に(?)深淵をのぞくことが出来て面白いです。

コンパイラが優秀な現代において実用的かはさておき、少なくともパズルとしては非常にエンジニア心をくすぐられます。
今回、その中からとっつきやすいテクニックを紹介します。

ビットマップのループを高速化

01011000
↓
f(00001000)
f(00010000)
f(01000000)

右から走査して1のときだけ処理を行いたいとき、ビットの長さ分ループしなくても

for (int x = 0b01011000; x > 0; x &= x - 1) {
  f(x & -x);
}

 x & -x と x &= x - 1 を繰り返せば、1の部分だけ処理することができます。

x & -x

01011000 = x
10100111 = not(x)
10101000 = not(x)+1 = -x

つまり

00001000 = x & -x

となって、右端の1だけが残ります。

x & x-1

01011000 = x
01010111 = x-1

つまり

01010000 = x & x-1

となって、右端の1が消えます。

ビットカウントを高速化(シンプル版)

00101110 10100000 11100101 11001011
↓
16

1の数を数えたいとき、ビットの数だけループしなくても

x = 0b00101110101000001110010111001011;
x = (x & 0x55555555) + ((x >>>  1) & 0x55555555); // --- (1)
x = (x & 0x33333333) + ((x >>>  2) & 0x33333333); // --- (2)
x = (x & 0x0F0F0F0F) + ((x >>>  4) & 0x0F0F0F0F); // --- (3)
x = (x & 0x00FF00FF) + ((x >>>  8) & 0x00FF00FF); // --- (4)
x = (x & 0x0000FFFF) + ((x >>> 16) & 0x0000FFFF); // --- (5)

log(ビット数)回のビット演算で求められます。
何やら複雑そうに見えますが、やっていることは部分和をx自身に書き込んでいく分割統治です。

(1)〜(5)まで1ステップずつ見ていきます。

(1)の式

データを2ビットずつに区切り、2ビットごとの1の数を求めてxを上書きします。

00 10 11 10 10 10 00 00 11 10 01 01 11 00 10 11 = x
↓
_0 _1 _2 _1 _1 _1 _0 _0 _2 _1 _1 _1 _2 _0 _1 _2 (10進表現)
↓
00 01 10 01 01 01 00 00 10 01 01 01 10 00 01 10 = 新しいx

式を分解して眺めると、2ビット単位で並列に足し算していることがわかると思います。

00 10 11 10 10 10 00 00 11 10 01 01 11 00 10 11 = x
↓
↓ 2ビット単位で上の桁、下の桁に分ける
↓
_0 _1 _1 _1 _1 _1 _0 _0 _1 _1 _0 _0 _1 _0 _1 _1 = ((x >>> 1) & 0x55555555)
_0 _0 _1 _0 _0 _0 _0 _0 _1 _0 _1 _1 _1 _0 _0 _1 = x & 0x55555555
↓
↓足す
↓
00 01 10 01 01 01 00 00 10 01 01 01 10 00 01 10 = (x & 0x55555555) + ((x >>> 1) & 0x55555555)

(2)の式

データを4ビットずつに区切り、4ビットごとの1の数を求めてxを上書きします。

0001 1001 0101 0000 1001 0101 1000 0110 = x
↓
_0_1 _2_1 _1_1 _0_0 _2_1 _1_1 _2_0 _1_2 (10進表現)
↓
___1 ___3 ___2 ___0 ___3 ___2 ___2 ___3 (10進表現)
↓
0001 0011 0010 0000 0011 0010 0010 0011 = 新しいx

式を分解して眺めると、4ビット単位で並列に足し算していることがわかると思います。

0001 1001 0101 0000 1001 0101 1000 0110 = x
↓
↓4ビット単位で上の桁、下の桁に分ける
↓
__00 __10 __01 __00 __10 __01 __10 __01 = (x >>> 2) & 0x33333333
__01 __01 __01 __00 __01 __01 __00 __10 = x & 0x33333333
↓
↓足す
↓
0001 0011 0010 0000 0011 0010 0010 0011 = (x & 0x33333333) + ((x >>> 2) & 0x33333333)

(3)の式

データを8ビットずつに区切り、8ビットごとの1の数を求めてxを上書きします。

00010011 00100000 00110010 00100011 = x
↓
___1___3 ___2___0 ___3___2 ___2___3 (10進表現)
↓
_______4 _______2 _______5 _______5 (10進表現)
↓
00000100 00000010 00000101 00000101 = 新しいx

(4)の式

データを16ビットずつに区切り、16ビットごとの1の数を求めてxを上書きする。

0000010000000010 0000010100000101 = x
↓
_______4_______2 _______5_______5 (10進表現)
↓
_______________6 ______________10 (10進表現)
↓
0000000000000110 0000000000001010 = 新しいx

(5)の式

データ全体の1の数を求めて、xに上書きします。

0000000000000110 0000000000001010 = 新しいx
↓
_______________6 ______________10 (10進表現)
↓
_______________________________16 (10進表現)
↓
0000000000000000 0000000000010000 = 新しいx
↓
16

このように、ビット演算を駆使してビットカウントが計算できました。

ビットカウントを高速化(最適化版)

実はもっと少ない演算でビットカウントを求めることが出来ます。
実際、OpenJDKの Integer#bitCount は次のようになっていて、先程のコードよりも短くなっています。
https://github.com/openjdk/jdk/blob/jdk-21%2B35/src/java.base/share/classes/java/lang/Integer.java#L1693

# 上で説明したコード
int bitCount(int i) {
    x = (x & 0x55555555) + ((x >>>  1) & 0x55555555);
    x = (x & 0x33333333) + ((x >>>  2) & 0x33333333);
    x = (x & 0x0F0F0F0F) + ((x >>>  4) & 0x0F0F0F0F);
    x = (x & 0x00FF00FF) + ((x >>>  8) & 0x00FF00FF);
    x = (x & 0x0000FFFF) + ((x >>> 16) & 0x0000FFFF);
    return x;
}

# OpenJDKのコード(抜粋&一部書き換え)
int bitCount(int x) {
    x = x - ((x >>> 1) & 0x55555555);
    x = (x & 0x33333333) + ((x >>> 2) & 0x33333333);
    x = (x + (x >>> 4)) & 0x0f0f0f0f;
    x = x + (x >>> 8);
    x = x + (x >>> 16);
    return x & 0x3f;
}

改良1

上の桁・下の桁でそれぞれANDするのは非効率なので、足した後にANDします。
ただし、3回目からのANDしか適用できません。
1回目はローカルのビットカウントが1bitに収まらず(0〜2個)、2回目も2bitに収まらないためです(0〜4個)。

コードにすると次のようになります。

int bitCount(int i) {
    x = (x & 0x55555555) + ((x >>>  1) & 0x55555555);
    x = (x & 0x33333333) + ((x >>>  2) & 0x33333333);
    x = (x + (x >>>  4)) & 0x0f0f0f0f;
    x = (x + (x >>>  8)) & 0x00ff00ff;
    x = (x + (x >>> 16)) & 0x0000ffff;
    return x;
}

また、最終的に上16bitが無視されるため、4回目のANDは不要です。
最後のANDもビットカウントの値域(0~32)を考慮すると

int bitCount(int i) {
    x = (x & 0x55555555) + ((x >>>  1) & 0x55555555);
    x = (x & 0x33333333) + ((x >>>  2) & 0x33333333);
    x = (x + (x >>>  4)) & 0x0f0f0f0f;
    x = x + (x >>>  8);
    x = x + (x >>> 16);
    return x & 0x3f;
}

このようになりました。

改良2

実は、ビットカウントは次のような式で求めることが出来ます(32bit)。

$$
bitcount(x) = x - floor(x/2) - floor(x/4) - … - floor(x/2^{31})
$$

これをそのまま実装すると

x = x - ((x >>> 1) & 0x7ffffffff);
x = x - ((x >>> 2) & 0x3ffffffff);
x = x - ((x >>> 3) & 0x1ffffffff);
x = x - ((x >>> 4) & 0x0ffffffff);
x = x - ((x >>> 5) & 0x07fffffff);
x = x - ((x >>> 6) & 0x03fffffff);
...

処理が30回近く連なり非効率に見えますが…

xが2bitのときに話を限定すると

$$
bitcount(x) = x - floor(x/2)
$$

となって、次のように引き算1回とシフト1回で求められます。

x = x - (x >>> 1)

$$
\begin{array}{l:l:l:l}
\textbf{x} & \textbf{bitcount(x)} & \textbf{x >>> 1} & \textbf{x - (x >>> 1)} \\ \hline
00 & 00 & 00 & 00 \\
01 & 01 & 00 & 01 \\
10 & 01 & 01 & 01 \\
11 & 10 & 01 & 10 \\ \hline
\end {array}
$$

ここで元のビットカウントのコードに戻り

x = (x & 0x55555555) + ((x >>>  1) & 0x55555555)

の部分で2ビットずつビットカウントを求めていることを考えれば

x = x - ((x >>> 1) & 0x55555555);

と書けるため、(1)をこのコードに置き換えることでOpenJDKのコードが再現できました。

int bitCount(int x) {
    x = x - ((x >>> 1) & 0x55555555);
    x = (x & 0x33333333) + ((x >>> 2) & 0x33333333);
    x = (x + (x >>> 4)) & 0x0f0f0f0f;
    x = x + (x >>> 8);
    x = x + (x >>> 16);
    return x & 0x3f;
}

trailing zeros

trailing zerosは右端から0が連続する数を10進数で表した数値です。
例えば、次のような計算結果になります。

trailing_zeros(10100000) = 5
trailing_zeros(11001100) = 2
trailing_zeros(01010110) = 1
trailing_zeros(01001101) = 0

さきほどビットマップのループでは、右端から1の部分だけ抽出してループを回しましたが、

01011000
↓
f(00001000)
f(00010000)
f(01000000)

受け手の関数としてはビットマップよりもインデックス値のほうが扱いやすいことが多いと思います。

01011000
↓
f(3)
f(4)
f(6)

このような場合にtrailing zerosが使えます。

標準的なプログラミング言語では、ライブラリ側で次のような関数が用意されていますが

# Java
Long.numberOfTrailingZeros(long bits)

これを敢えて使わない場合、次のようなアルゴリズムで実装できます。

ビットカウントによる方法

シンプルな方法としては、さきほどのビットカウントを使う方法があります。

最右の1ビットを抜き出したビット列から1を引くと、
末尾まで1ビットが連続するのでビットカウントした値がそのままtrailing zerosに。

00100000 - 1 = 00011111
00000001 - 1 = 00000000
# 注:0だと正しい値にならないので別途考慮が必要

bitcount(00011111) = 5
bitcount(00000000) = 0

Log(ビット長)のオーダで簡単に計算できました。

De Brujin Sequenceによる方法

掛け算1回+シフト1回+配列アクセス1回で済ます方法があります。

De Brujin Sequence

De Brujin Sequence(ド・ブラウンと読みます)という名のマジックナンバー 00011101 を掛け算して

00100000 * 00011101 = 10100000
00000001 * 00011101 = 00011101

マジックナンバーの長さに応じて右シフトして(8 - log(8)ビット)

10100000 >>> 5 = 00000101
00011101 >>> 5 = 00000000

あらかじめ作った参照表をみると

$$
\begin{array}{l:l:l}
\textbf{ビットパターン(下3桁)} & \textbf{trailing zeros} \\ \hline
000 & 0 \\
001 & 1 \\
010 & 6 \\
011 & 2 \\
100 & 7 \\
101 & 5 \\
110 & 4 \\
111 & 3 \\ \hline
\end {array}
$$

00000101 -> 5
00000000 -> 0

と計算できます。

解説

De Brujin Sequence はグレイコードをひとまとめにした数字です。

例えば、8ビットの De Brujin Sequence は3桁のグレイコードをつなげたものになっています。

00011101 に対して、先頭から順番に3桁ずつ切り取る=5ケタの右シフト演算をすると
000  001  011  111  110  101  010  100  000
となって、8ビットの中に0〜7が折り畳まれていることがわかります。

つまり最右の1ビットをDe Brujin Sequenceに掛ける(8ビット)ということは、最右の1ビットをグレイコード=0〜7にマッピングするのと同じです。

あとはインデックスにグレイコード、値にtrailing zerosを対応させた配列を作っておけばOK。

$$
\begin{array}{l:l:l:l}
\textbf{①De Brujin Sequence} & \textbf{②最右の1ビット} & \textbf{①×②} & \textbf{先頭3桁} \\ \hline
00011101 & 00000001 & 00011101 & 000 \\
00011101 & 00000010 & 00111010 & 001 \\
00011101 & 00000100 & 01110100 & 011 \\
00011101 & 00001000 & 11101000 & 111 \\
00011101 & 00010000 & 11010000 & 110 \\
00011101 & 00100000 & 10100000 & 101 \\
00011101 & 01000000 & 01000000 & 010 \\
00011101 & 10000000 & 10000000 & 100 \\ \hline
\end{array}
$$

掛け算が高速に実行できるのであれば、ビットカウントを使うより効率的です。
ちなみに、GoのTrailingZeros関数はDe Brujin Sequenceで実装されています。

Goのbitsパッケージ(ソースコード)
https://go.dev/src/math/bits/bits.go?#L90

ビット長ごとのDe Brujin Sequence

8ビットの場合を例にとりましたが、16ビット、32ビット、64ビット版は次の通りです。

$$
\begin{array}{l:l:l}
\textbf{ビット長} & \textbf{De Brujin Sequence} & \textbf{右シフト長} \\ \hline
8 & 0x1D または 0x17 & 5 \\
16 & 0x0D2F など16種 & 12 \\
32 & 0x077CB531 など2048種 & 27 \\
64 & 0x03F79D71B4CA8B09 など 67108864種 & 58 \\ \hline
\end {array}
$$

32ビットの値はGo実装、
64ビットの値はThe Art of Computer Programming=Go実装、
De Brujin Sequenceの数や他の値の例についてはChess Programming Wikiから引用しています。

ここまで書いたけど

自分で実装することはまずありません…
現場では、該当する関数をライブラリ経由で呼んだり、CPUの拡張命令に頼るべきだと思います。

ビットカウント

先程触れましたが、標準ライブラリで実装されていたりします。
Java : Integer.bitCount()、Long.bitCount()
Go : bitsパッケージのOnesCountXX()

さらに言うと、15年前ぐらいからCPUの拡張命令(IntelであればSSE4.2のPOPCNT)としても実装されているので、
そもそも本気で高速化するなら直接POPCNTを呼べる言語を利用すべきです。

Javaに関してはJITがbitCount()をPOPCNTに置き換えてくれるようです。

trailing zeros

BMI1拡張命令のひとつ「tzcnt」命令として実装されてるので、
言語によってはライブラリが提供する関数で勝手に最適化されるかもしれません。

いいなと思ったら応援しよう!