見出し画像

WhisperKitサンプルのソースコードを読む

先日WhisperKitについて、ざっくり紹介するLTを行った:

で、紹介しきれなかったことはこちらの記事に書いた:

さらに書ききれなかったのが、サンプルのコードリーディング。

READMEには以下の2行で実装できると記載してあったのだが、

let pipe = try? await WhisperKit()
let transcription = try? await pipe!.transcribe(audioPath: "path/to/your/audio.{wav,mp3,m4a,flac}")?.text
print(transcription)

本当に2行で実装してるの?というところをサンプルコードを追いかけてみた。

結論としては2行でできるのも嘘ではないが、サンプルはもっといろいろやっている。(まぁそうだろうな、という妥当な結論)


主な処理の流れ

サンプルアプリの実装はほぼすべてが ContentView.swift に書かれている。

オーディオファイルの書き起こしは、ピッカーでファイル選択したところで `handleFilePicker(result:)` メソッドが呼ばれ、その中で `transcribeFile(path:)` メソッドが呼ばれ、その中で `transcribeCurrentFile(path:)` メソッドが呼ばれる。

ここらへんからが音声認識処理の実装で、ざっくり次の3行が「音声ファイルをWhisperKitで書き起こす」処理のコア部分。

guard let audioFileBuffer = AudioProcessor.loadAudio(fromPath: path) else {
    return
}

let audioFileSamples = AudioProcessor.convertBufferToArray(buffer: audioFileBuffer)
let transcription = try await transcribeAudioSamples(audioFileSamples)

loadAudio メソッド

引数に渡されたファイルパスから音声データを読み出して AVAudioPCMBuffer 型を返すメソッド。

public static func loadAudio(fromPath audioFilePath: String) -> AVAudioPCMBuffer? {
    guard FileManager.default.fileExists(atPath: audioFilePath) else {
        Logging.error("Resource path does not exist \(audioFilePath)")
        return nil
    }

    var outputBuffer: AVAudioPCMBuffer?

    do {
        let audioFileURL = URL(fileURLWithPath: audioFilePath)
        let audioFile = try AVAudioFile(forReading: audioFileURL, commonFormat: .pcmFormatFloat32, interleaved: false)

        let sampleRate = audioFile.fileFormat.sampleRate
        let channelCount = audioFile.fileFormat.channelCount
        let frameLength = AVAudioFrameCount(audioFile.length)

        // If the audio file already meets the desired format, read directly into the output buffer
        if sampleRate == 16000 && channelCount == 1 {
            guard let buffer = AVAudioPCMBuffer(pcmFormat: audioFile.processingFormat, frameCapacity: frameLength) else {
                Logging.error("Unable to create audio buffer")
                return nil
            }
            try audioFile.read(into: buffer)
            outputBuffer = buffer
        } else {
            // Audio needs resampling to 16khz
            guard let buffer = resampleAudio(fromFile: audioFile, toSampleRate: 16000, channelCount: 1) else {
                Logging.error("Unable to resample audio")
                return nil
            }
            outputBuffer = buffer
        }

        if let buffer = outputBuffer {
            Logging.info("Audio source details - Sample Rate: \(sampleRate) Hz, Channel Count: \(channelCount), Frame Length: \(frameLength), Duration: \(Double(frameLength) / sampleRate)s")
            Logging.info("Audio buffer details - Sample Rate: \(buffer.format.sampleRate) Hz, Channel Count: \(buffer.format.channelCount), Frame Length: \(buffer.frameLength), Duration: \(Double(buffer.frameLength) / buffer.format.sampleRate)s")
        }
    } catch {
        Logging.error("Error loading audio file: \(error)")
        return nil
    }

    return outputBuffer
}

期待するフォーマット(サンプルレート16k, 1チャンネル)でない場合に再サンプリングする実装(resampleAudio メソッド)も入っている。

resampleAudio の実装までは再掲しないが、ファイルからのオーディオ読み出し時には参考になりそう。

convertBufferToArray メソッド

AVAudioPCMBuffer から実際の波形データを読み出して Float の配列として返すメソッド:

public static func convertBufferToArray(buffer: AVAudioPCMBuffer) -> [Float] {
    let start = buffer.floatChannelData?[0]
    let count = Int(buffer.frameLength)
    let convertedArray = Array(UnsafeBufferPointer(start: start, count: count))
    return convertedArray
}

transcribeAudioSamples メソッド

波形データを受け取り、書き起こし結果を返すメソッド。ここで WhisperKit クラスを使っている。

func transcribeAudioSamples(_ samples: [Float]) async throws -> TranscriptionResult? {
    guard let whisperKit = whisperKit else { return nil }

    let languageCode = whisperKit.tokenizer?.languages[selectedLanguage] ?? "en"
    let task: DecodingTask = selectedTask == "transcribe" ? .transcribe : .translate
    let seekClip = [lastConfirmedSegmentEndSeconds]

    let options = DecodingOptions(
        verbose: false,
        task: task,
        language: languageCode,
        temperatureFallbackCount: 3, // limit fallbacks for realtime
        sampleLength: Int(sampleLength), // reduced sample length for realtime
        usePrefillPrompt: enablePromptPrefill,
        usePrefillCache: enableCachePrefill,
        skipSpecialTokens: !enableSpecialCharacters,
        withoutTimestamps: !enableTimestamps,
        clipTimestamps: seekClip
    )

    // Early stopping checks
    let decodingCallback: ((TranscriptionProgress) -> Bool?) = { progress in
        DispatchQueue.main.async {
            let fallbacks = Int(progress.timings.totalDecodingFallbacks)
            if progress.text.count < currentText.count {
                if fallbacks == self.currentFallbacks {
                    self.unconfirmedText.append(currentText)
                } else {
                    print("Fallback occured: \(fallbacks)")
                }
            }
            self.currentText = progress.text
            self.currentFallbacks = fallbacks
        }
        // Check early stopping
        let currentTokens = progress.tokens
        let checkWindow = Int(compressionCheckWindow)
        if currentTokens.count > checkWindow {
            let checkTokens: [Int] = currentTokens.suffix(checkWindow)
            let compressionRatio = compressionRatio(of: checkTokens)
            if compressionRatio > options.compressionRatioThreshold! {
                return false
            }
        }
        if progress.avgLogprob! < options.logProbThreshold! {
            return false
        }

        return nil
    }

    let transcription = try await whisperKit.transcribe(audioArray: samples, decodeOptions: options, callback: decodingCallback)
    return transcription
}

重要なのは最後の WhisperKit の transcribe メソッドを呼んでいるところだが、READMEで「2行でOK」と言ってたわりにやたら多い。これらは何の実装なのか?

見ていくと、transcribe メソッドの第2、第3引数に渡す DecodingOptions オブジェクトと TranscriptionCallback オブジェクトを生成しているようである。

(READMEの「2行でOK」のサンプルを見返すと、これらの引数を省略していたようだ。)

DecodingOptions

この部分:

let languageCode = whisperKit.tokenizer?.languages[selectedLanguage] ?? "en"
let task: DecodingTask = selectedTask == "transcribe" ? .transcribe : .translate
let seekClip = [lastConfirmedSegmentEndSeconds]

let options = DecodingOptions(
    verbose: false,
    task: task,
    language: languageCode,
    temperatureFallbackCount: 3, // limit fallbacks for realtime
    sampleLength: Int(sampleLength), // reduced sample length for realtime
    usePrefillPrompt: enablePromptPrefill,
    usePrefillCache: enableCachePrefill,
    skipSpecialTokens: !enableSpecialCharacters,
    withoutTimestamps: !enableTimestamps,
    clipTimestamps: seekClip
)

細かいカスタマイズはここでできそう。いずれ詳細に見ていきたい。

TranscriptionCallback

この部分:

// Early stopping checks
let decodingCallback: ((TranscriptionProgress) -> Bool?) = { progress in
    DispatchQueue.main.async {
        let fallbacks = Int(progress.timings.totalDecodingFallbacks)
        if progress.text.count < currentText.count {
            if fallbacks == self.currentFallbacks {
                self.unconfirmedText.append(currentText)
            } else {
                print("Fallback occured: \(fallbacks)")
            }
        }
        self.currentText = progress.text
        self.currentFallbacks = fallbacks
    }
    // Check early stopping
    let currentTokens = progress.tokens
    let checkWindow = Int(compressionCheckWindow)
    if currentTokens.count > checkWindow {
        let checkTokens: [Int] = currentTokens.suffix(checkWindow)
        let compressionRatio = compressionRatio(of: checkTokens)
        if compressionRatio > options.compressionRatioThreshold! {
            return false
        }
    }
    if progress.avgLogprob! < options.logProbThreshold! {
        return false
    }

    return nil
}

アーリーストッピングのチェックを行っている。

パッと読んでわからなかったのでChatGPTに聞いてみた。嘘もあるかもしれないが、自分で類推するよりは妥当そうな感じがする。

Q: 以下のコードは、WhisperKitというOpenAIのWhisperのCore ML版OSSの実装で、書き起こし処理中に渡せるコールバック実装部分にかかれていた実装です。アーリーストッピングを行っているようですが、どういう条件で何のために停止を行っていると考えられますか?変数名などから類推してください

以下回答:

このコードは、音声の自動書き起こしを行う際のアーリーストッピングのロジックを含んでいるようです。アーリーストッピングは、特定の条件が満たされた場合に処理を早期に終了させるテクニックで、無駄な計算資源の消費を避けたり、品質の低下を防ぐために使われます。このコードから読み取れるアーリーストッピングの条件と目的を、変数名やコードの構造から類推してみましょう。

1. 圧縮率(compressionRatio)によるチェック

let currentTokens = progress.tokens
let checkWindow = Int(compressionCheckWindow)
if currentTokens.count > checkWindow {
    let checkTokens: [Int] = currentTokens.suffix(checkWindow)
    let compressionRatio = compressionRatio(of: checkTokens)
    if compressionRatio > options.compressionRatioThreshold! {
        return false
    }
}

ここでは、最近のトークン群(`checkTokens`)の圧縮率を計算し、それが設定された閾値(`compressionRatioThreshold`)を超えた場合に処理を停止します。圧縮率が高いとは、多くの重複または冗長性が存在することを意味し、これ以上の処理が品質向上に寄与しない可能性があるため、アーリーストッピングの条件として利用されています。

2. 平均対数確率(avgLogprob)によるチェック

if progress.avgLogprob! < options.logProbThreshold! {
    return false
}

平均対数確率(`avgLogprob`)がある閾値(`logProbThreshold`)未満の場合、処理を停止します。この対数確率は、モデルが生成したトークンの確信度を示し、低い値はモデルがその予測に自信がない、すなわち書き起こしの品質が低い可能性があることを意味します。したがって、この値が低いと判断された場合、処理を早期に終了させることで、不正確な書き起こしを避けることができます。

デコーディングのフォールバックに関する処理

if progress.text.count < currentText.count {
    if fallbacks == self.currentFallbacks {
        self.unconfirmedText.append(currentText)
    } else {
        print("Fallback occured: \(fallbacks)")
    }
}

この部分は直接的なアーリーストッピングの条件ではありませんが、デコーディング中に「フォールバック」が発生したことを検出しています。フォールバックが発生すると、書き起こしの品質が低下する可能性があるため、この情報は品質管理やデバッグに役立ちます。

結論

このコードによるアーリーストッピングは、書き起こしの品質を維持しつつ、無駄な計算資源の消費を避けるために設計されています。圧縮率が高いことによる冗長性の発生、またはモデルの自信の低い予測を防ぐことにより、効率的かつ高品質な書き起こしを目指しています。


ここから先は

0字
文章やサンプルコードは多少荒削りかもしれませんが、ブログや書籍にはまだ書いていないことを日々大量に載せています。たったの400円で、すぐに購読解除してもその月は過去記事もさかのぼって読めるので、少しでも気になる内容がある方にはオトクかと思います。

技術的なメモやサンプルコード、思いついたアイデア、考えたこと、お金の話等々、頭をよぎった諸々を気軽に垂れ流しています。

最後まで読んでいただきありがとうございます!もし参考になる部分があれば、スキを押していただけると励みになります。 Twitterもフォローしていただけたら嬉しいです。 https://twitter.com/shu223/