WhisperKitサンプルのソースコードを読む
先日WhisperKitについて、ざっくり紹介するLTを行った:
で、紹介しきれなかったことはこちらの記事に書いた:
さらに書ききれなかったのが、サンプルのコードリーディング。
READMEには以下の2行で実装できると記載してあったのだが、
let pipe = try? await WhisperKit()
let transcription = try? await pipe!.transcribe(audioPath: "path/to/your/audio.{wav,mp3,m4a,flac}")?.text
print(transcription)
本当に2行で実装してるの?というところをサンプルコードを追いかけてみた。
結論としては2行でできるのも嘘ではないが、サンプルはもっといろいろやっている。(まぁそうだろうな、という妥当な結論)
主な処理の流れ
サンプルアプリの実装はほぼすべてが ContentView.swift に書かれている。
オーディオファイルの書き起こしは、ピッカーでファイル選択したところで `handleFilePicker(result:)` メソッドが呼ばれ、その中で `transcribeFile(path:)` メソッドが呼ばれ、その中で `transcribeCurrentFile(path:)` メソッドが呼ばれる。
ここらへんからが音声認識処理の実装で、ざっくり次の3行が「音声ファイルをWhisperKitで書き起こす」処理のコア部分。
guard let audioFileBuffer = AudioProcessor.loadAudio(fromPath: path) else {
return
}
let audioFileSamples = AudioProcessor.convertBufferToArray(buffer: audioFileBuffer)
let transcription = try await transcribeAudioSamples(audioFileSamples)
loadAudio メソッド
引数に渡されたファイルパスから音声データを読み出して AVAudioPCMBuffer 型を返すメソッド。
public static func loadAudio(fromPath audioFilePath: String) -> AVAudioPCMBuffer? {
guard FileManager.default.fileExists(atPath: audioFilePath) else {
Logging.error("Resource path does not exist \(audioFilePath)")
return nil
}
var outputBuffer: AVAudioPCMBuffer?
do {
let audioFileURL = URL(fileURLWithPath: audioFilePath)
let audioFile = try AVAudioFile(forReading: audioFileURL, commonFormat: .pcmFormatFloat32, interleaved: false)
let sampleRate = audioFile.fileFormat.sampleRate
let channelCount = audioFile.fileFormat.channelCount
let frameLength = AVAudioFrameCount(audioFile.length)
// If the audio file already meets the desired format, read directly into the output buffer
if sampleRate == 16000 && channelCount == 1 {
guard let buffer = AVAudioPCMBuffer(pcmFormat: audioFile.processingFormat, frameCapacity: frameLength) else {
Logging.error("Unable to create audio buffer")
return nil
}
try audioFile.read(into: buffer)
outputBuffer = buffer
} else {
// Audio needs resampling to 16khz
guard let buffer = resampleAudio(fromFile: audioFile, toSampleRate: 16000, channelCount: 1) else {
Logging.error("Unable to resample audio")
return nil
}
outputBuffer = buffer
}
if let buffer = outputBuffer {
Logging.info("Audio source details - Sample Rate: \(sampleRate) Hz, Channel Count: \(channelCount), Frame Length: \(frameLength), Duration: \(Double(frameLength) / sampleRate)s")
Logging.info("Audio buffer details - Sample Rate: \(buffer.format.sampleRate) Hz, Channel Count: \(buffer.format.channelCount), Frame Length: \(buffer.frameLength), Duration: \(Double(buffer.frameLength) / buffer.format.sampleRate)s")
}
} catch {
Logging.error("Error loading audio file: \(error)")
return nil
}
return outputBuffer
}
期待するフォーマット(サンプルレート16k, 1チャンネル)でない場合に再サンプリングする実装(resampleAudio メソッド)も入っている。
resampleAudio の実装までは再掲しないが、ファイルからのオーディオ読み出し時には参考になりそう。
convertBufferToArray メソッド
AVAudioPCMBuffer から実際の波形データを読み出して Float の配列として返すメソッド:
public static func convertBufferToArray(buffer: AVAudioPCMBuffer) -> [Float] {
let start = buffer.floatChannelData?[0]
let count = Int(buffer.frameLength)
let convertedArray = Array(UnsafeBufferPointer(start: start, count: count))
return convertedArray
}
transcribeAudioSamples メソッド
波形データを受け取り、書き起こし結果を返すメソッド。ここで WhisperKit クラスを使っている。
func transcribeAudioSamples(_ samples: [Float]) async throws -> TranscriptionResult? {
guard let whisperKit = whisperKit else { return nil }
let languageCode = whisperKit.tokenizer?.languages[selectedLanguage] ?? "en"
let task: DecodingTask = selectedTask == "transcribe" ? .transcribe : .translate
let seekClip = [lastConfirmedSegmentEndSeconds]
let options = DecodingOptions(
verbose: false,
task: task,
language: languageCode,
temperatureFallbackCount: 3, // limit fallbacks for realtime
sampleLength: Int(sampleLength), // reduced sample length for realtime
usePrefillPrompt: enablePromptPrefill,
usePrefillCache: enableCachePrefill,
skipSpecialTokens: !enableSpecialCharacters,
withoutTimestamps: !enableTimestamps,
clipTimestamps: seekClip
)
// Early stopping checks
let decodingCallback: ((TranscriptionProgress) -> Bool?) = { progress in
DispatchQueue.main.async {
let fallbacks = Int(progress.timings.totalDecodingFallbacks)
if progress.text.count < currentText.count {
if fallbacks == self.currentFallbacks {
self.unconfirmedText.append(currentText)
} else {
print("Fallback occured: \(fallbacks)")
}
}
self.currentText = progress.text
self.currentFallbacks = fallbacks
}
// Check early stopping
let currentTokens = progress.tokens
let checkWindow = Int(compressionCheckWindow)
if currentTokens.count > checkWindow {
let checkTokens: [Int] = currentTokens.suffix(checkWindow)
let compressionRatio = compressionRatio(of: checkTokens)
if compressionRatio > options.compressionRatioThreshold! {
return false
}
}
if progress.avgLogprob! < options.logProbThreshold! {
return false
}
return nil
}
let transcription = try await whisperKit.transcribe(audioArray: samples, decodeOptions: options, callback: decodingCallback)
return transcription
}
重要なのは最後の WhisperKit の transcribe メソッドを呼んでいるところだが、READMEで「2行でOK」と言ってたわりにやたら多い。これらは何の実装なのか?
見ていくと、transcribe メソッドの第2、第3引数に渡す DecodingOptions オブジェクトと TranscriptionCallback オブジェクトを生成しているようである。
(READMEの「2行でOK」のサンプルを見返すと、これらの引数を省略していたようだ。)
DecodingOptions
この部分:
let languageCode = whisperKit.tokenizer?.languages[selectedLanguage] ?? "en"
let task: DecodingTask = selectedTask == "transcribe" ? .transcribe : .translate
let seekClip = [lastConfirmedSegmentEndSeconds]
let options = DecodingOptions(
verbose: false,
task: task,
language: languageCode,
temperatureFallbackCount: 3, // limit fallbacks for realtime
sampleLength: Int(sampleLength), // reduced sample length for realtime
usePrefillPrompt: enablePromptPrefill,
usePrefillCache: enableCachePrefill,
skipSpecialTokens: !enableSpecialCharacters,
withoutTimestamps: !enableTimestamps,
clipTimestamps: seekClip
)
細かいカスタマイズはここでできそう。いずれ詳細に見ていきたい。
TranscriptionCallback
この部分:
// Early stopping checks
let decodingCallback: ((TranscriptionProgress) -> Bool?) = { progress in
DispatchQueue.main.async {
let fallbacks = Int(progress.timings.totalDecodingFallbacks)
if progress.text.count < currentText.count {
if fallbacks == self.currentFallbacks {
self.unconfirmedText.append(currentText)
} else {
print("Fallback occured: \(fallbacks)")
}
}
self.currentText = progress.text
self.currentFallbacks = fallbacks
}
// Check early stopping
let currentTokens = progress.tokens
let checkWindow = Int(compressionCheckWindow)
if currentTokens.count > checkWindow {
let checkTokens: [Int] = currentTokens.suffix(checkWindow)
let compressionRatio = compressionRatio(of: checkTokens)
if compressionRatio > options.compressionRatioThreshold! {
return false
}
}
if progress.avgLogprob! < options.logProbThreshold! {
return false
}
return nil
}
アーリーストッピングのチェックを行っている。
パッと読んでわからなかったのでChatGPTに聞いてみた。嘘もあるかもしれないが、自分で類推するよりは妥当そうな感じがする。
以下回答:
1. 圧縮率(compressionRatio)によるチェック
let currentTokens = progress.tokens
let checkWindow = Int(compressionCheckWindow)
if currentTokens.count > checkWindow {
let checkTokens: [Int] = currentTokens.suffix(checkWindow)
let compressionRatio = compressionRatio(of: checkTokens)
if compressionRatio > options.compressionRatioThreshold! {
return false
}
}
2. 平均対数確率(avgLogprob)によるチェック
if progress.avgLogprob! < options.logProbThreshold! {
return false
}
デコーディングのフォールバックに関する処理
if progress.text.count < currentText.count {
if fallbacks == self.currentFallbacks {
self.unconfirmedText.append(currentText)
} else {
print("Fallback occured: \(fallbacks)")
}
}
結論
最後まで読んでいただきありがとうございます!もし参考になる部分があれば、スキを押していただけると励みになります。 Twitterもフォローしていただけたら嬉しいです。 https://twitter.com/shu223/