見出し画像

Whisperの音声文字起こしの修正に第二候補以降が使えないか試してみる

OpenAIが2022/9に公開した音声文字起こしモデルのWhisperですが、日本語音声に対しても性能がかなり高いようで、個人的にかなり期待があります。

Whisperの各サイズのモデルで長文の文字起こし試した記事を書いてくださっている方がおり、特にlargeモデルはすごく正確な結果になっていることがわかります。

実応用では、より正確な文字起こし結果を残したいようなユースケースがあり、文字起こし結果を見ながら手で修正することがあると思います。

例えば「汎化性能の高いモデルを作る」という音声をsmallモデルで文字おこしするとこうなります。

ここで、「半花」の部分を修正するとき、手で打ち直さなくても、第二候補などで「汎化」が取得できていると候補から選択するだけで作業できるはずです。

WhisperのDecoding方法を眺める

WhisperのDecodingは、GPTなどの言語モデルのDecodingと同じ方式で、前から1tokenずつ順番にDecodingする形式です。

各ステップでは、次のtokenのlogitがvocab内の全tokenに対して得られていて、デフォルトではもっとも確率の高いtokenを採用するGreedy Decodingが使われています。

実装としては、decoding.pyのGreedyDecoder.updateで全トークンのlogitsからargmaxでnext_tokenを選んでいます。
https://github.com/openai/whisper/blob/main/whisper/decoding.py#L473

class GreedyDecoder(TokenDecoder):

#(略)

    def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[Tensor, bool]:
        temperature = self.temperature
        if temperature == 0:
            next_tokens = logits.argmax(dim=-1)

#(略)

TopKのトークンを確認する

各ステップでのtopk候補のtokenを見るには、このGreedyDecoderを改造するとよさそうです。
手軽に見てみるために、このGreedyDecoder.update関数をモンキーパッチをして、リスト変数seq_logitsに各ステップでのlogitsが格納されるようにします。
update関数を書き換えたのは、3行目の seq_logits.append(logits)のみです。

import whisper
from whisper.decoding import GreedyDecoder
from typing import Tuple
import torch
import torch.nn.functional as F
from torch import Tensor
from torch.distributions import Categorical

seq_logits = []
def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[Tensor, bool]:
    temperature = self.temperature
    seq_logits.append(logits)
    if temperature == 0:
        next_tokens = logits.argmax(dim=-1)
    else:
        next_tokens = Categorical(logits=logits / temperature).sample()

    logprobs = F.log_softmax(logits.float(), dim=-1)
    current_logprobs = logprobs[torch.arange(logprobs.shape[0]), next_tokens]
    sum_logprobs += current_logprobs * (tokens[:, -1] != self.eot)

    next_tokens[tokens[:, -1] == self.eot] = self.eot
    tokens = torch.cat([tokens, next_tokens[:, None]], dim=-1)

    completed = (tokens[:, -1] == self.eot).all()
    return tokens, completed

GreedyDecoder.update = update
model = whisper.load_model("small")
result = model.transcribe("data/sample_01.wav")

ライブラリのソースファイル自体を編集する必要はなく、クラスをimportしてメンバ関数を書き換える方法を使っています。
(下から三行目の GreedyDecoder.update= update の部分)

この状態でmodel.transcribeを呼ぶことで、seq_logits変数にlogitsを入れることができました。

decodingの各ステップでのlogits

logitsのTopKを単語にする

WhisperのTokenizerを使って、このseq_logitsのTopKを単語に変換します。
WhisperのTokenizerは、get_tokenizer関数で取得できます。
seq_logitsをtopkをとって、tokenizer.decodeで1tokenずつデコードしました。

from whisper.tokenizer import get_tokenizer, Tokenizer
tokenizer = get_tokenizer(multilingual=True, language="ja", task="transcribe")

import pandas as pd
k = 3
topk_token_ids = torch.stack([torch.argsort(logits)[0][-k:] for logits in seq_logits]).T.numpy()

topk_decoded_words = [[tokenizer.decode(idx) for idx in token_ids] for token_ids in topk_token_ids[::-1]]
pd.DataFrame(topk_decoded_words)
pd.DataFrame(topk_decoded_words)

一番上の行がtop1の結果(model.transcribeの返却)です。
2行目以降は1行目の直前トークンまでで条件づけたときの、候補topkが入っています。
分かりづらいので例を出します。2列目に2トークン目の候補として「花」「化」「火」が並んでいます。これは、1行目の1トークン目「半」の後に続く予測として、「半花」「半化」「半火」がtop1,2,3だったことを意味しています。
1文字目の候補として「汎」が出てくることを期待しましたが、出てきませんでした。これは、「汎」が予測されにくいということ以外に、tokenizerに由来する理由があります。

ところどころで出てきている�について

�はtokenがutf-8にできなかったことを意味しています。
正確には理解できていないのですが、Whisperのtokenizerは、日本語の1文字を2tokenで表現する場合があるようで、1tokenでは日本語の1文字が表現できていないことがあります。
それを無理やりprintすると�が表示されます。

例えば、今回出てくることを期待していた「汎」は、tokenizeすると2token(12800, 236)で表現されています。

12800、236は個別にdecodeしようとするとに�なりますが、つなげると「汎」にdecodeできます。

「汎」を作るのに2tokenかかるため、今回のように1tokenだけの第二候補以降を出しても「汎」は出てこないことがわかります。

BeamSearchの途中結果を使って2token以上先を見る

「汎」を出すには、次の1tokenだけの上位候補ではなく、それらの候補について、さらに1token先を見る必要があります。

次の1tokenだけであれば、デフォルトのGreedyDecoderの処理で計算されているlogitsをとっておくだけでよかったのですが、さらに次の1tokenを見るには、上位候補を末尾に追加した文字列を作って、modelをforwardさせる必要があり、追加の計算コストが発生します。

これはBeamSearchで似たようなことをやっていて、BeamSearchDecoder.updateで計算されている情報を使うとよさそうです。
https://github.com/openai/whisper/blob/main/whisper/decoding.py#L310

update関数がだいぶ長いので省略しますが、beamごとに次のtopk候補を計算している部分で、sequence変数をグローバル変数candidatesに入れるようにパッチします(candidates.append(sequence)が追加した行)

candidates = []
def beam_update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[Tensor, bool]:
# (略) 
               for logprob, token in zip(*logprobs[idx].topk(self.beam_size + 1)):
                    new_logprob = (sum_logprobs[idx] + logprob).item()
                    sequence = tuple(prefix + [token.item()])
                    candidates.append(sequence)
# (略) 
BeamSearchDecoder.update = beam_update

ビームサーチを有効にするには、model.transcribeでbeam_sizeを渡します。

result_beam = model.transcribe("sample.wav", beam_size=3)

これで、candidates変数に、その時点でのtoken_idの候補が入りました。
中を見てみます(最初の4トークンはタスク名や言語などの特殊文字なので省いています)

[(cand[4:], tokenizer.decode(cand[4:])) for cand in candidates]

((30018,), '半'),
((15927,), 'ハ'),
((23141,), '�'),
((2129,), '�'),
((30018, 20127), '半花'),
((30018, 23756), '半化'),
((30018, 26171), '半火'),
((30018, 11561), '半�'),
((15927, 4824), 'ハン'),
((15927, 14121), 'ハン�'),
((15927, 32026), 'ハム'),
((15927, 20745), 'ハウ'),
((23141, 223), '繁'),
((23141, 225), '繃'),
((23141, 108), '繰'),
((23141, 233), '繋'),
…(略)

2tokenまでのbeam候補

2token目までのBeamSearchの候補を見てみると、
2tokenで1文字の「繁」「繃」「繰」「繋」が候補として見えるようになりました。
残念ながら「汎」は候補に入りませんでしたが…

デメリットとして、Beam SearchはデフォルトのGreedy Searchよりもモデルの推論が多く呼ばれるので処理が遅くなりますが、この方法なら修正候補をある程度提示できる見込みがありそうです。

この記事が気に入ったらサポートをしてみませんか?