セミリアルタイム、高速イメージ・アップスケールAPI

real-SERGANを用いたアップスケールAPIサーバとクリアント側の実装例です。リアルタイムまでは届きませんでした。

環境
CUDA 12.2
PyTORCH 2.2.1
Python 3.9
GPU:RTX4090
CPU:Core™ i5-13600K

変換速度
TCP/IP版 16.5fps
FastAPI版 15fps

TCP/IP版

送受信にプロトコルによるオーバーヘッドが少ないと予想される手法です。

データをシリアライズしてTCP/IPパケットにそのまま載せています。受信側はパケット受信後にデシリアライズで元データに復元し、呼び出し側へ結果としてリターンしています。クライアント側は専用の通信関数が必要なのでサーバとクライアントが対になります。

FastAPI版

ASOGフォームで通信を行います。すなわちPOST/GETでデータを送受信できます。通信処理のオーバーヘッドが増えるので、若干パフォーマンスが低下しています。

コード

サーバ部全体のコードはTCP/IP版、FastAPI版共に最後に記載します。

TCP/IP版の実装


アーギュメント処理とイニシャライズ

モデル選択ができるよになっているのでこの部分が長いですが、やっていることは単純です。モデルパスは面倒なので削除しました。modelsディレクトリを作成してダウンロードしたモデルを入れておいてください。

def main():
    global upsampler

    parser = argparse.ArgumentParser()
    parser.add_argument('-i', '--input', type=str, default='inputs', help='Input image or folder')
    parser.add_argument('-n','--model_name', type=str, default='RealESRGAN_x4plus', help=('Model names: RealESRGAN_x4plus | RealESRNet_x4plus | RealESRGAN_x4plus_anime_6B | RealESRGAN_x2plus | realesr-animevideov3 | realesr-general-x4v3'))
    parser.add_argument('-o', '--output', type=str, default='results', help='Output folder')
    parser.add_argument('-dn','--denoise_strength',type=float, default=0.5, help=('Denoise strength. 0 for weak denoise (keep noise), 1 for strong denoise ability. Only used for the realesr-general-x4v3 model'))
    parser.add_argument('-s', '--outscale', type=float, default=4, help='The final upsampling scale of the image')
    parser.add_argument( '-t', '--test', type=bool, default=False, help='excecute test PG if True')
    parser.add_argument("--host", type=str,  default="0.0.0.0",  help="サービスを提供するip アドレスを指定。")
    parser.add_argument("--port", type=int,  default=50008,    help="サービスを提供するポートを指定。")
    args = parser.parse_args()

    # determine models according to model names
    args.model_name = args.model_name.split('.')[0]
    if args.model_name == 'RealESRGAN_x4plus':  # x4 RRDBNet model
        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
        netscale = 4
    elif args.model_name == 'RealESRGAN_x4plus_anime_6B':  # x4 RRDBNet model with 6 blocks
        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
        netscale = 4
    elif args.model_name == 'RealESRGAN_x2plus':  # x2 RRDBNet model
        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
        netscale = 2
    elif args.model_name == 'realesr-animevideov3':  # x4 VGG-style model (XS size)
        model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu')
        netscale = 4
    elif args.model_name == 'realesr-general-x4v3':  # x4 VGG-style model (S size)
        model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
        netscale = 4

    #+++++++++++++++++++  init  +++++++++++++++++++
    model_path = "./weights/" + args.model_name +".pth"
    print(model_path )
    print(netscale)
    # use dni to control the denoise strength
    dni_weight = None
    if args.model_name == 'realesr-general-x4v3' and args.denoise_strength != 1:
        wdn_model_path = model_path.replace('realesr-general-x4v3', 'realesr-general-wdn-x4v3')
        model_path = [model_path, wdn_model_path]
        dni_weight = [args.denoise_strength, 1 - args.denoise_strength]
    # restorer
    upsampler = RealESRGANer(
        scale=netscale,
        model_path=model_path,
        dni_weight=dni_weight,
        model=model,
        tile=0,
        tile_pad=10,
        pre_pad=0,
        half=True,
        gpu_id=0)

通信モジュールとサーバ
単純に受信待ちで監視し、クライアントからリクエストが来れば、
up_scale(img , scale)関数を呼び、結果をクリアントへ返送するだけの簡単な処理です。

    # ++++++++++++++   TCP/IP server ++++++++
    if args.test==False:
        host=args.host
        port=args.port
        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)    # ソケット定義(IPv4,TCPによるソケット)
        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        s.bind((host,port))
        s.listen(10) # ソケット接続待受(キューの最大数を指定)
        while True:
          try:   
                try:                        # ソケット接続受信待ち
                    print(host,port,'クライアントからの接続待ち...')
                    clientsock, client_address = s.accept()
                    print("Conected client= ", client_address," Conection date_time= ",datetime.now())
                except KeyboardInterrupt:   # 接続待ちの間に強制終了が入った時の例外処理
                    clientsock.shutdown(1)# データ送信完了後、送信路を閉じる
                    break
                else:                       # 接続待ちの間に強制終了なく、クライアントからの接続が来た場合
                    all_data=b''                   # 受信データ保存用変数の初期化
                    while True:                 # ソケット接続開始後の処理
                        data = clientsock.recv(4096*256)  # データ受信。受信バッファサイズ1024バイト
                        if not data:            # 全データ受信完了(受信路切断)時に、ループ離脱
                            break
                        all_data += data        # 受信データを追加し繋げていく
                    get_data=(pickle.loads(all_data)) #受信データ解析 元の形式にpickle.loadsで復元    get_data[0]=OpenCV-image ,  get_data[1]=scale
                    tx_gen_out=up_scale(get_data[0], float(get_data[1]))#  背景削除実行
                    # 結果をクライアントに送信
                    tx_dat=pickle.dumps(tx_gen_out,5)
                    clientsock.send(tx_dat) #pickle.dumpsでシリアライズ
                    clientsock.shutdown(1)# データ送信完了後、送信路を閉じる
                    print(client_address," へ送信完了","Tx bytes=",len(tx_dat)/1000,"kB",datetime.now())
          except:
            print("connection error")

アップスケール関数
ここが、今回のreal-ESERの実行を行う部分です。サーバ化しない場合はこの関数をプログラムから呼び出せば直接実行できます。通信プロトコルが入らないのでパフォーマンスは上がります。モデル関係の初期化は別途必要です。

def  up_scale(img , scale):
        global upsampler
        try:
            output, _ = upsampler.enhance(img , outscale=scale)
        except RuntimeError as error:
            print('Error', error)
            print('If you encounter CUDA out of memory, try to set --tile with a smaller number.')
        return output

TESTコード
アップスケール関数が正しいかどうかを確かめるための処理です。冒頭でif args.test==True: としているので、--test Falseで実行されません。デフォルトは実行されます。

#+++++++++++++++++++ TEST +++++++++++++++++++
    if args.test==True:
        if os.path.isfile(args.input):
            paths = [args.input]
        else:
            paths = sorted(glob.glob(os.path.join(args.input, '*')))
        img_list=[]
        for idx, path in enumerate(paths):
            imgname, extension = os.path.splitext(os.path.basename(path))
            print('Testing', idx, imgname)
            cv_img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
            img_list.append(cv_img)
        print("start_time=",datetime.now())
        count=len(img_list)
        for i in range(0,count):
            img=img_list[i]
            output = up_scale(img , args.outscale)
            if len(img.shape) == 3 and img.shape[2] == 4:
                extension = '.png'
            else:
                extension = '.jpg'
            save_path = "./results/" + args.output+ str(i)+extension
            cv2.imwrite(save_path, output) #if files are require
        print("end_time=",datetime.now())

クライント側
前半はアーギュメントとTESTプログラム、中間にTCP/IPプロトコルがあります。最後の up_scale(img , scale) をアプリにimportして呼び出せばアップスケールされた画像が受け取れます。scaleは倍率で2,4,8が指定出来ます。イニシャライズは不要です。
クライアント側全コード

import argparse
import cv2
import glob
import os
from datetime import datetime
import pickle
import socket


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('-i', '--input', type=str, default='inputs', help='Input image or folder')
    parser.add_argument('-o', '--output', type=str, default='results', help='Output folder')
    parser.add_argument('-s', '--outscale', type=float, default=4, help='The final upsampling scale of the image')
    parser.add_argument( '-t', '--test', type=bool, default=False, help='excecute test PG if True')
    args = parser.parse_args()

    #+++++++++++++++++++ TEST +++++++++++++++++++
    if args.test==True:
        if os.path.isfile(args.input):
            paths = [args.input]
        else:
            paths = sorted(glob.glob(os.path.join(args.input, '*')))
        img_list=[]
        for idx, path in enumerate(paths):
            imgname, extension = os.path.splitext(os.path.basename(path))
            print('Testing', idx, imgname)
            cv_img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
            img_list.append(cv_img)
        start_time = datetime.now()
        count=len(img_list)
        for i in range(0,count):
            img=img_list[i]
            output = up_scale(img , args.outscale) # <<<<<<<<<<<<<<<<<<<<< inference関数  up_scale(img , scale)
            if len(img.shape) == 3 and img.shape[2] == 4:
                extension = '.png'
            else:
                extension = '.jpg'
            save_path = args.output + "/" + str(i)+extension
            print(save_path)
            cv2.imwrite(save_path, output) #if files are require
        print("No of pictures =",count)
        print("Start_time=", start_time)
        print("End_time  =",datetime.now())

# ++++++++++++++   TCP/IP server ++++++++
def get_out(tx_list):
    host="127.0.0.1"    # サーバーIPアドレス定義
    port=8001           # サーバー待ち受けポート番号定義
    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)    # ソケットクライアント作成
    s.connect((host, port))         # 送信先サーバーに接続
    all_data=b''                    # 受信データ保存用変数の初期化
    print(host , port , "del_bkg function TX  start",datetime.now())
    tx_data=pickle.dumps(tx_list,5)
    s.send(tx_data) #pickle.dumpsで送信データをシリアライズしサーバに送信
    s.shutdown(1)# データ送信完了後、送信路を閉じる
    while True:                 # ソケット接続開始後の処理
                data = s.recv(4096*256)  # データ受信。受信バッファサイズ4096*256バイト
                if not data:            # 全データ受信完了(受信路切断)時に、ループ離脱
                    break
                all_data += data        # 受信データを追加し繋げていく
    get_out =(pickle.loads(all_data))#元の形式にpickle.loadsで復元
    print("upscal function RX done",datetime.now())
    return  get_out

# ++++++++++++++  up scale ++++++++++++++++
def  up_scale(img , scale):
         tx_list=[]
         tx_list.append(img)
         tx_list.append(scale)
         try:
            image =get_out(tx_list)
         except RuntimeError as error:
            print('Error', error)
         return  image

if __name__ == '__main__':
    main()



FastAPI版の実装

FastAPIによるASOGフォームの通信を用いた実装です。冒頭にも書いたようにパフォーマンスは若干落ちます。

コード
冒頭に up_scale(img , scale) 関数があります。real-ESERの実行を行う部分です。

# ++++++++++++++  up scale ++++++++++++++++
def  up_scale(img , scale):
        print("inf_start_time=",datetime.now())
        global upsampler
        try:
            output, _ = upsampler.enhance(img , outscale=scale)
        except RuntimeError as error:
            print('Error', error)
            print('If you encounter CUDA out of memory, try to set --tile with a smaller number.')
        print("inf_end_time=",datetime.now())
        return output

その後に続く部分がアーギュメント処理部分です。
更にTCP/IPと同じテストプログラムがあります。

通信部

FastAPIによる通信処理部です。エンドポイントは
resr_upscalのみです。クライントから受け取るデータはOpenCV形式の画像データ、パラメータはint型のscaleです。
out_img = up_scale(img ,scale)
でアップスケール処理を呼び出して、結果のイメージを返送しています。
形式がjesonではないので注意してください。

# =============    FastAPI  ============
app = FastAPI()

@app.post("/resr_upscal/")
async  def resr_upscal(file: UploadFile = File(...), scale:int  = Form(...)): #file=OpenCV
    print("scale=",scale)
    scale=float(scale)
    file_contents = await file.read()
    nparr = np.frombuffer(file_contents, np.uint8) # バイナリデータをNumPy配列に変換
    img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)      # OpenCVで画像として読み込む
    
    out_img  = up_scale(img ,scale)

    frame_data = pickle.dumps(out_img, 5)  # tx_dataはpklデータ、イメージのみ返送
    print("send_time=",datetime.now())
    return Response(content=frame_data, media_type="application/octet-stream")

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8008)

APIクライント
こちらも前半はTCP/IPと同じで、アーギュメント処理とテストプログラム、後半に通信用関数が有ります。
import up_scale
でアプリから使うことができます。イニシャライズは不要です。本番ではテストプログラム部分は不要です。なのでわずかなコードだけです。
up_scale(url , img , scale)関数を直接プログラムに埋め込んでも動きます。
クライアント側全コード

import argparse
import cv2
import glob
import os
from datetime import datetime
import pickle
import requests

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('-i', '--input', type=str, default='inputs', help='Input image or folder')
    parser.add_argument('-o', '--output', type=str, default='results', help='Output folder')
    parser.add_argument('-s', '--outscale', type=str, default=4, help='The final upsampling scale of the image')
    parser.add_argument( '-t', '--test', type=bool, default=False, help='excecute test PG if True')
    parser.add_argument("--host", type=str,  default="0.0.0.0",  help="サービスを提供するip アドレスを指定。")
    parser.add_argument("--port", type=int,  default=50008,    help="サービスを提供するポートを指定。")
    args = parser.parse_args()

    host="0.0.0.0"    # サーバーIPアドレス定義
    port=8008          # サーバー待ち受けポート番号定義
    url="http://" + host + ":" + str(port) + "/resr_upscal/"
    
    #+++++++++++++++   test    +++++++++++++++++++
    if args.test==True:
        if os.path.isfile(args.input):
            paths = [args.input]
        else:
            paths = sorted(glob.glob(os.path.join(args.input, '*')))
        img_list=[]
        for idx, path in enumerate(paths):
            imgname, extension = os.path.splitext(os.path.basename(path))
            print('Testing', idx, imgname)
            cv_img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
            img_list.append(cv_img)
        start_time = datetime.now()
        count=len(img_list)
        for i in range(0,count):
            img=img_list[i]
            output = up_scale(url, img , args.outscale) # <<<<<<<<<<<<<<<<<<   up_scale(url , img ,  scale):
            if len(img.shape) == 3 and img.shape[2] == 4:
                extension = '.png'
            else:
                extension = '.jpg'
            save_path = args.output + "/" + str(i)+extension
            print(save_path)
            cv2.imwrite(save_path, output) #if files are require #ファイルへ書き出しをすると遅くなります。
        print("start_time=",start_time)
        print("end_time=",datetime.now())

# ++++++++++++++  up scale ++++++++++++++++
def up_scale(url , img ,  scale):
    _, img_encoded = cv2.imencode('.jpg', img)
    response = requests.post(url, files={"file": ("image.jpg", img_encoded.tobytes(), "image/jpeg"),"scale":(None,scale)})
    all_data =response.content
    up_data = (pickle.loads(all_data))#元の形式にpickle.loadsで復元
    return up_data #形式はimg_mode指定の通り

if __name__ == '__main__':
    main()

サーバの全コード

TCP/IP版

import argparse
import cv2
import glob
import os
from basicsr.archs.rrdbnet_arch import RRDBNet
from basicsr.utils.download_util import load_file_from_url
from realesrgan import RealESRGANer
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
from datetime import datetime
import socket
import pickle

def main():
    global upsampler

    parser = argparse.ArgumentParser()
    parser.add_argument('-i', '--input', type=str, default='inputs', help='Input image or folder')
    parser.add_argument('-n','--model_name', type=str, default='RealESRGAN_x4plus', help=('Model names: RealESRGAN_x4plus | RealESRNet_x4plus | RealESRGAN_x4plus_anime_6B | RealESRGAN_x2plus | realesr-animevideov3 | realesr-general-x4v3'))
    parser.add_argument('-o', '--output', type=str, default='results', help='Output folder')
    parser.add_argument('-dn','--denoise_strength',type=float, default=0.5, help=('Denoise strength. 0 for weak denoise (keep noise), 1 for strong denoise ability. Only used for the realesr-general-x4v3 model'))
    parser.add_argument('-s', '--outscale', type=float, default=4, help='The final upsampling scale of the image')
    parser.add_argument( '-t', '--test', type=bool, default=False, help='excecute test PG if True')
    parser.add_argument("--host", type=str,  default="0.0.0.0",  help="サービスを提供するip アドレスを指定。")
    parser.add_argument("--port", type=int,  default=50008,    help="サービスを提供するポートを指定。")
    args = parser.parse_args()

    # determine models according to model names
    args.model_name = args.model_name.split('.')[0]
    if args.model_name == 'RealESRGAN_x4plus':  # x4 RRDBNet model
        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
        netscale = 4
    elif args.model_name == 'RealESRGAN_x4plus_anime_6B':  # x4 RRDBNet model with 6 blocks
        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
        netscale = 4
    elif args.model_name == 'RealESRGAN_x2plus':  # x2 RRDBNet model
        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
        netscale = 2
    elif args.model_name == 'realesr-animevideov3':  # x4 VGG-style model (XS size)
        model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu')
        netscale = 4
    elif args.model_name == 'realesr-general-x4v3':  # x4 VGG-style model (S size)
        model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
        netscale = 4

    #+++++++++++++++++++  init  +++++++++++++++++++
    model_path = "./weights/" + args.model_name +".pth"
    print(model_path )
    print(netscale)
    # use dni to control the denoise strength
    dni_weight = None
    if args.model_name == 'realesr-general-x4v3' and args.denoise_strength != 1:
        wdn_model_path = model_path.replace('realesr-general-x4v3', 'realesr-general-wdn-x4v3')
        model_path = [model_path, wdn_model_path]
        dni_weight = [args.denoise_strength, 1 - args.denoise_strength]
    # restorer
    upsampler = RealESRGANer(
        scale=netscale,
        model_path=model_path,
        dni_weight=dni_weight,
        model=model,
        tile=0,
        tile_pad=10,
        pre_pad=0,
        half=True,
        gpu_id=0)

#+++++++++++++++++++ TEST +++++++++++++++++++
    if args.test==True:
        if os.path.isfile(args.input):
            paths = [args.input]
        else:
            paths = sorted(glob.glob(os.path.join(args.input, '*')))
        img_list=[]
        for idx, path in enumerate(paths):
            imgname, extension = os.path.splitext(os.path.basename(path))
            print('Testing', idx, imgname)
            cv_img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
            img_list.append(cv_img)
        print("start_time=",datetime.now())
        count=len(img_list)
        for i in range(0,count):
            img=img_list[i]
            output = up_scale(img , args.outscale)
            if len(img.shape) == 3 and img.shape[2] == 4:
                extension = '.png'
            else:
                extension = '.jpg'
            save_path = "./results/" + args.output+ str(i)+extension
            cv2.imwrite(save_path, output) #if files are require
        print("end_time=",datetime.now())

    # ++++++++++++++   TCP/IP server ++++++++
    if args.test==False:
        host=args.host
        port=args.port
        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)    # ソケット定義(IPv4,TCPによるソケット)
        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        s.bind((host,port))
        s.listen(10) # ソケット接続待受(キューの最大数を指定)
        while True:
          try:   
                try:                        # ソケット接続受信待ち
                    print(host,port,'クライアントからの接続待ち...')
                    clientsock, client_address = s.accept()
                    print("Conected client= ", client_address," Conection date_time= ",datetime.now())
                except KeyboardInterrupt:   # 接続待ちの間に強制終了が入った時の例外処理
                    clientsock.shutdown(1)# データ送信完了後、送信路を閉じる
                    break
                else:                       # 接続待ちの間に強制終了なく、クライアントからの接続が来た場合
                    all_data=b''                   # 受信データ保存用変数の初期化
                    while True:                 # ソケット接続開始後の処理
                        data = clientsock.recv(4096*256)  # データ受信。受信バッファサイズ1024バイト
                        if not data:            # 全データ受信完了(受信路切断)時に、ループ離脱
                            break
                        all_data += data        # 受信データを追加し繋げていく
                    get_data=(pickle.loads(all_data)) #受信データ解析 元の形式にpickle.loadsで復元    get_data[0]=OpenCV-image ,  get_data[1]=scale
                    tx_gen_out=up_scale(get_data[0], float(get_data[1]))#  背景削除実行
                    # 結果をクライアントに送信
                    tx_dat=pickle.dumps(tx_gen_out,5)
                    clientsock.send(tx_dat) #pickle.dumpsでシリアライズ
                    clientsock.shutdown(1)# データ送信完了後、送信路を閉じる
                    print(client_address," へ送信完了","Tx bytes=",len(tx_dat)/1000,"kB",datetime.now())
          except:
            print("connection error")
# ++++++++++++++  up scale ++++++++++++++++
def  up_scale(img , scale):
        global upsampler
        try:
            output, _ = upsampler.enhance(img , outscale=scale)
        except RuntimeError as error:
            print('Error', error)
            print('If you encounter CUDA out of memory, try to set --tile with a smaller number.')
        return output

if __name__ == '__main__':
    main()

FastAPI版サーバ

import argparse
import cv2
import glob
import os
from basicsr.archs.rrdbnet_arch import RRDBNet
from basicsr.utils.download_util import load_file_from_url
from realesrgan import RealESRGANer
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
from datetime import datetime
import pickle
from fastapi import FastAPI, File, UploadFile, Form
from starlette.responses import Response
from io import BytesIO
import numpy as np

# ++++++++++++++  up scale ++++++++++++++++
def  up_scale(img , scale):
        print("inf_start_time=",datetime.now())
        global upsampler
        try:
            output, _ = upsampler.enhance(img , outscale=scale)
        except RuntimeError as error:
            print('Error', error)
            print('If you encounter CUDA out of memory, try to set --tile with a smaller number.')
        print("inf_end_time=",datetime.now())
        return output

parser = argparse.ArgumentParser()
parser.add_argument('-i', '--input', type=str, default='inputs', help='Input image or folder')
parser.add_argument('-n','--model_name', type=str, default='RealESRGAN_x4plus', help=('Model names: RealESRGAN_x4plus | RealESRNet_x4plus | RealESRGAN_x4plus_anime_6B | RealESRGAN_x2plus | realesr-animevideov3 | realesr-general-x4v3'))
parser.add_argument('-o', '--output', type=str, default='results', help='Output folder')
parser.add_argument('-dn','--denoise_strength',type=float, default=0.5, help=('Denoise strength. 0 for weak denoise (keep noise), 1 for strong denoise ability. Only used for the realesr-general-x4v3 model'))
parser.add_argument('-s', '--outscale', type=float, default=4, help='The final upsampling scale of the image')
parser.add_argument( '-t', '--test', type=bool, default=False, help='excecute test PG if True')
parser.add_argument("--host", type=str,  default="0.0.0.0",  help="サービスを提供するip アドレスを指定。")
parser.add_argument("--port", type=int,  default=50008,    help="サービスを提供するポートを指定。")
args = parser.parse_args()

# determine models according to model names
args.model_name = args.model_name.split('.')[0]
if args.model_name == 'RealESRGAN_x4plus':  # x4 RRDBNet model
        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
        netscale = 4
elif args.model_name == 'RealESRGAN_x4plus_anime_6B':  # x4 RRDBNet model with 6 blocks
        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
        netscale = 4
elif args.model_name == 'RealESRGAN_x2plus':  # x2 RRDBNet model
        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
        netscale = 2
elif args.model_name == 'realesr-animevideov3':  # x4 VGG-style model (XS size)
        model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu')
        netscale = 4
elif args.model_name == 'realesr-general-x4v3':  # x4 VGG-style model (S size)
        model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
        netscale = 4

 #+++++++++++++++++++  init  +++++++++++++++++++
model_path = "./weights/" + args.model_name +".pth"
print(model_path )
print(netscale)
# use dni to control the denoise strength
dni_weight = None
if args.model_name == 'realesr-general-x4v3' and args.denoise_strength != 1:
        wdn_model_path = model_path.replace('realesr-general-x4v3', 'realesr-general-wdn-x4v3')
        model_path = [model_path, wdn_model_path]
        dni_weight = [args.denoise_strength, 1 - args.denoise_strength]
    # restorer
upsampler = RealESRGANer(
        scale=netscale,
        model_path=model_path,
        dni_weight=dni_weight,
        model=model,
        tile=0,
        tile_pad=10,
        pre_pad=0,
        half=True,
        gpu_id=0)

#+++++++++++++++++++ TEST +++++++++++++++++++
if args.test==True:
        if os.path.isfile(args.input):
            paths = [args.input]
        else:
            paths = sorted(glob.glob(os.path.join(args.input, '*')))
        img_list=[]
        for idx, path in enumerate(paths):
            imgname, extension = os.path.splitext(os.path.basename(path))
            print('Testing', idx, imgname)
            cv_img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
            img_list.append(cv_img)
        print("start_time=",datetime.now())
        count=len(img_list)
        for i in range(0,count):
            img=img_list[i]
            output = up_scale(img , args.outscale)
            if len(img.shape) == 3 and img.shape[2] == 4:
                extension = '.png'
            else:
                extension = '.jpg'
            save_path = "./results/" + args.output+ str(i)+extension
            cv2.imwrite(save_path, output) #if files are require
        print("end_time=",datetime.now())

# =============    FastAPI  ============
app = FastAPI()

@app.post("/resr_upscal/")
async  def resr_upscal(file: UploadFile = File(...), scale:int  = Form(...)): #file=OpenCV
    print("scale=",scale)
    scale=float(scale)
    file_contents = await file.read()
    nparr = np.frombuffer(file_contents, np.uint8) # バイナリデータをNumPy配列に変換
    img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)      # OpenCVで画像として読み込む
    
    out_img  = up_scale(img ,scale)

    frame_data = pickle.dumps(out_img, 5)  # tx_dataはpklデータ、イメージのみ返送
    print("send_time=",datetime.now())
    return Response(content=frame_data, media_type="application/octet-stream")

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8008)