見出し画像

新サンプラーのアイデア

はじめに

※これは、私の純粋なアイデアではなくて、私が管理しているディスコードサーバーで出ていたアイデアをもとに自分なりに落とし込んだものです。

※このアイデアを元に、どなたかが実装してくれるととても嬉しいです。
というか、考え方が間違っているかもしれませんし、指摘してもらえると嬉しいです。

※私はまったくのpython初心者です。コードに誤りがある可能性もありますし、コードを反映した結果、正しく生成が行えなくなる可能性もあります。
バックアップを取ったり、別環境を用意するなど、自己責任にてチャレンジしてください。


Restartっていうサンプラー知ってますか?

生成中に1~2回のi2iを挟むような挙動をするサンプラーです。WEBUI1111の1.6.0から導入されました。仕上がりが美しくなるため私は多用していました。

リアルタイムプレビューで見ていると、完成仕掛けていた出力が途中でノイズが入って生成しなおしているような挙動が確認できます。(この記事ではこの動作をリスタートと呼ぶことにします。)

このリスタートの挙動は、一旦収束しかけたところで、もう一度ノイズを乗せて生成を進めるという一回の生成でi2iをしちゃおうという感じのものです。このリスタート、とてもいいのですが、どんなにステップ数を増やしても最大で2回までしかリスタートしてくれないのが不満でした。

「もっとたくさんリスタートしたい!!」

これが、今回の新しいサンプラーのきっかけです。

ソースを確認しよう

Restartサンプラーのコードは、"\stable-diffusion-webui\modules\sd_samplers_extra.py"にあります。
そして、このコードの中ほどに以下の記述があります。

    steps = sigmas.shape[0] - 1
    if restart_list is None:
        if steps >= 20:
            restart_steps = 9
            restart_times = 1
            if steps >= 36:
                restart_steps = steps // 4
                restart_times = 2
            sigmas = get_sigmas_karras(steps - restart_steps * restart_times, sigmas[-2].item(), sigmas[0].item(), device=sigmas.device)
            restart_list = {0.1: [restart_steps + 1, restart_times, 2]}
        else:
            restart_list = {}

ここがリスタートの処理回数などを決定してるルーチンです。
解説すると、

コードのこの部分では、まず全ステップ数stepsが20以上かどうかをチェックしています。次に、ステップ数に応じてリスタートの回数restart_timesとリスタートする際のステップ数restart_stepsを設定しています。
stepsが20以上の場合、リスタートは1回行われ、9ステップで行われます。
stepsが36以上の場合、リスタートは2回行われ、ステップ数を4で割った値がリスタートするステップ数となります。
その後、get_sigmas_karras関数を使用して、リスタート回数とステップ数を考慮に入れた新しいシグマのスケジュールを生成しています。そして、restart_listを設定しており、これはリスタートを制御するためのパラメータ(リスタートステップ、回数、最大シグマ)を持つ辞書です。この場合、どのステップ数に対しても、シグマが0.1の時にリスタートステップ数とリスタート回数に基づいてリスタートが行われるように設定しています。ここでのrestart_listはデフォルトの設定値であり、特定のステップ数におけるリスタートの構成を定義しています。

ChatGPT

ということです。(よくわかってない)

やってみたこと

とにかく、restart_stepsとrestart_timesという変数があって、restart_timesの数だけリスタートしてくれるんだな!ということで、早速書き換えてみました。

    steps = sigmas.shape[0] - 1
    restart_steps = 8
    if restart_list is None:
        if steps >= restart_steps * 2:
            restart_times = 1
            if steps >= restart_steps * 3:
                restart_times = steps // restart_steps - 1
            sigmas = get_sigmas_karras(steps - restart_steps * restart_times, sigmas[-2].item(), sigmas[0].item(), device=sigmas.device)
            restart_list = {0.1: [restart_steps + 1, restart_times, 2]}
        else:
            restart_list = {}

リスタート回数をステップ数に合わせて無制限にしています。
何度もリスタートを繰り返すことで、矛盾や破綻が減っていきクオリティが上がっていくことが確認できました。
2行目のrestart_stepsの値を変更することで、全体のリスタート回数が変わります。この値は、リスタートするまでのステップ数を表しています。高いとリスタート回数が減り、低いと、リスタート回数が増えます。
低すぎると、リスタートあたりの生成が十分ではなくなり、クオリティに影響します。
※通常モデルの推奨は6以上ですが、LCMモデルだともっと下げられるかもしれません。

実装方法(拡張機能)

11/24追記

拡張機能で作っていただきました!!

L4Ph.moe様、ありがとうございます!
一応、以下の情報も残しておきます。拡張機能で導入した方は、試してみよう!まで飛ばしてください。

12/03追記
別の方にも、作っていただきました!

ci_ka様、ありがとうございます!
設定機能がモリモリです。さすがLab!!!

実装方法(付け焼き刃)

ただし、このままだと元々のrestartサンプラーが使えなくなってしまいますし、そうすると他の方との情報のやり取りにも混乱が生じてしまいますので、別のファイルを用意してサンプラー自体を増やしてみましょう。

以下の内容のコードを「sd_samplers_extra_m.py」という名前で保存して、先程のフォルダに入れましょう。

import torch
import tqdm
import k_diffusion.sampling


@torch.no_grad()
def restart_sampler_m(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1., restart_list=None):
    """Implements restart sampling in Restart Sampling for Improving Generative Processes (2023)
    Restart_list format: {min_sigma: [ restart_steps, restart_times, max_sigma]}
    If restart_list is None: will choose restart_list automatically, otherwise will use the given restart_list
    """
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])
    step_id = 0
    from k_diffusion.sampling import to_d, get_sigmas_karras

    def heun_step(x, old_sigma, new_sigma, second_order=True):
        nonlocal step_id
        denoised = model(x, old_sigma * s_in, **extra_args)
        d = to_d(x, old_sigma, denoised)
        if callback is not None:
            callback({'x': x, 'i': step_id, 'sigma': new_sigma, 'sigma_hat': old_sigma, 'denoised': denoised})
        dt = new_sigma - old_sigma
        if new_sigma == 0 or not second_order:
            # Euler method
            x = x + d * dt
        else:
            # Heun's method
            x_2 = x + d * dt
            denoised_2 = model(x_2, new_sigma * s_in, **extra_args)
            d_2 = to_d(x_2, new_sigma, denoised_2)
            d_prime = (d + d_2) / 2
            x = x + d_prime * dt
        step_id += 1
        return x

    steps = sigmas.shape[0] - 1
    restart_steps = 8
    if restart_list is None:
        if steps >= restart_steps * 2:
            restart_times = 1
            if steps >= restart_steps * 3:
                restart_times = steps // restart_steps - 1
            sigmas = get_sigmas_karras(steps - restart_steps * restart_times, sigmas[-2].item(), sigmas[0].item(), device=sigmas.device)
            restart_list = {0.1: [restart_steps + 1, restart_times, 2]}
        else:
            restart_list = {}

    restart_list = {int(torch.argmin(abs(sigmas - key), dim=0)): value for key, value in restart_list.items()}

    step_list = []
    for i in range(len(sigmas) - 1):
        step_list.append((sigmas[i], sigmas[i + 1]))
        if i + 1 in restart_list:
            restart_steps, restart_times, restart_max = restart_list[i + 1]
            min_idx = i + 1
            max_idx = int(torch.argmin(abs(sigmas - restart_max), dim=0))
            if max_idx < min_idx:
                sigma_restart = get_sigmas_karras(restart_steps, sigmas[min_idx].item(), sigmas[max_idx].item(), device=sigmas.device)[:-1]
                while restart_times > 0:
                    restart_times -= 1
                    step_list.extend([(old_sigma, new_sigma) for (old_sigma, new_sigma) in zip(sigma_restart[:-1], sigma_restart[1:])])

    last_sigma = None
    for old_sigma, new_sigma in tqdm.tqdm(step_list, disable=disable):
        if last_sigma is None:
            last_sigma = old_sigma
        elif last_sigma < old_sigma:
            x = x + k_diffusion.sampling.torch.randn_like(x) * s_noise * (old_sigma ** 2 - last_sigma ** 2) ** 0.5
        x = heun_step(x, old_sigma, new_sigma)
        last_sigma = new_sigma

    return x

次に、同じフォルダにある「sd_samplers_kdiffusion.py」というファイルを編集します。

4行目の末尾に「, sd_samplers_extra_m」を追加します。

from modules import sd_samplers_common, sd_samplers_extra, sd_samplers_cfg_denoiser, sd_samplers_extra_m

38行目のRestartの次の行に以下の行を追加します。

    ('multiRestart', sd_samplers_extra_m.restart_sampler_m, ['mrestart'], {'scheduler': 'karras', "second_order": True}),

ここまで編集したら保存してWEBUI1111のコマンドプロンプトを再起動しましょう。

サンプリングに「multiRestart」が増えているはずです。

試してみよう!

それでは早速やってみましょう。
Sampling stepsは32以上(できれば50以上、推奨100)、Hires stepsは16以上がおすすめです。※たぶんモデルによって異なると思います。

サンプラー比較

リアルタイムプレビューで見ているとどんどんと変わって面白いですね!

触ってみた実感としてはRestart回数が増えると単純に品質が上がる気がします。要はi2iの回数が増えるようなものですから、よりプロンプトに忠実になってくれるのではないでしょうか。多少の破綻もなおしてくれるような気がします。

おわりに

最初にも申しましたが、私がほんとうにコーディング初心者のため、どなたかこのサンプラーを拡張機能とかで実装していただけるとありがたいです。

追記

2023/11/22 14:00
色々サンプル出して試していたら、restart_stepは6がよりよい出力ができていますので、6に変更しました。これ以前に実装して8になっている方は6に変更をお願いします。

追記の追記

2023/11/22 14:30
これくらいの誤差です。この値もSettingのsampler parametersとかでせっていできるといいですよねー。どなたか実装してもらえないかなー(チラッ

restart_step 6
restart_step 8


この記事が気に入ったらサポートをしてみませんか?