見出し画像

Motion Diffusion Model によるテキストからのモーション生成を試す

「Human Motion Diffusion Model 」によるテキストからの3D生成を試したので、まとめてみました。

1. Motion Diffusion Model

「Motion Diffusion Model」は、テキストから人物のモーションを生成することができる手法です。

2. インストール

Google Colabでの「Motion Diffusion Model」のインストール手順は、次のとおりです。

(1) メニュー「編集→ノートブックの設定」で、「ハードウェアアクセラレータ」に「GPU」を選択。

(2) # GPUの確認。

!nvidia-smi
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   45C    P8     9W /  70W |      0MiB / 15109MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+

(3) パッケージのインストール。

# パッケージのインストール
!pip install git+https://github.com/openai/CLIP.git
!pip install smplx chumpy

(4) motion-diffusion-modelのクローン。
出力結果の保存先フォルダも準備します。

# motion-diffusion-modelのクローン
!git clone https://github.com/GuyTevet/motion-diffusion-model
!mkdir motion-diffusion-model/save

(5) HumanML3Dデータの取得。
データの解凍に少し時間がかかります。

# HumanML3Dデータの取得
!git clone https://github.com/EricGuo5513/HumanML3D.git
!unzip ./HumanML3D/HumanML3D/texts.zip -d ./HumanML3D/HumanML3D/
!cp -r HumanML3D/HumanML3D motion-diffusion-model/dataset/HumanML3D

(6) SMPLモデルの取得。

# SMPLモデルの取得
%cd motion-diffusion-model
!bash prepare/download_smpl_files.sh

(7) motion-diffusion-modelリポジトリから 事前学習モデル (humanml-encoder-512) をダウンロードし、Googleドライブのマイドライブ直下にアップロード。

(8) Googleドライブのマウント。

# Googleドライブのマウント
from google.colab import drive
drive.mount('/content/drive')

(9) saveフォルダに移動し、事前学習モデル (humanml-encoder-512) をコピーして解凍。

# 事前学習モデルをsaveにコピーして解凍
%cd save
!cp '/content/drive/My Drive/humanml_trans_enc_512.zip' .
!unzip ./humanml_trans_enc_512.zip

(10) motion-diffusion-modelフォルダに移動。

%cd ..

3. 推論の実行

推論の実行手順は、次のとおりです。

(1) sampleスクリプトを実行。

!python -m sample --model_path ./save/humanml_trans_enc_512/model000200000.pt --text_prompt "the person walked forward and is picking up his toolbox."

・--model_path : 事前学習モデル (humanml-encoder-512) のパス
・--text_prompt : 生成するモーションを説明するテキストプロンプト
・--device : デバイスID
・--seed : 乱数シード
・--motion_length : モーションの長さ (最大9.8[sec])

the person walked forward and is picking up his toolbox.
人は前に歩き、ツールボックスを拾う

(2) 「motion-diffusion-model/save/humanml_trans_enc_512/samples_humanml_trans_enc_512_000200000_seed10_the_person_walked_forward_and_is_picking_up_his_toolbox」に以下のファイルが出力されているのでダウンロード。

・sampleXX_repXX.mp4 : 生成された各モーションのスティックフィギュアアニメーション。
・results.npy : テキストプロンプトと生成されたアニメーションのxyz位置を含む npy ファイル

(3) 「sampleXX_repXX.mp4」を再生して、生成されたモーションを確認。

(4) 「results.npy」の中身の確認。

# results.npyの中身の確認
import numpy as np
data = np.load('./save/humanml_trans_enc_512/samples_humanml_trans_enc_512_000200000_seed10_the_person_walked_forward_and_is_picking_up_his_toolbox/results.npy', allow_pickle=True)
print(data)
{
    'motion': array([
       [[[ 0.00000000e+00,  5.21681388e-04,  1.82657817e-03, ...,
           1.00562662e-01,  1.07239805e-01,  1.13251105e-01],
         [ 9.47966337e-01,  9.47464526e-01,  9.46544170e-01, ...,
           9.41673279e-01,  9.40844715e-01,  9.39939260e-01],
         [ 0.00000000e+00,  4.12540045e-03,  8.82862695e-03, ...,
           6.39105797e-01,  6.47692919e-01,  6.56137586e-01]],

        [[ 5.14577143e-02,  5.19877449e-02,  5.33060804e-02, ...,
           1.53042093e-01,  1.59575894e-01,  1.65455416e-01],
         [ 8.26417565e-01,  8.25979412e-01,  8.25160027e-01, ...,
           8.17770720e-01,  8.16819251e-01,  8.15796733e-01],
         [-3.43051460e-03,  9.58818942e-04,  6.07601227e-03, ...,
           6.47698402e-01,  6.55840993e-01,  6.64001107e-01]],

        [[-6.35429397e-02, -6.28210232e-02, -6.11487776e-02, ...,
           4.07312848e-02,  4.72071692e-02,  5.31035736e-02],
         [ 8.21423233e-01,  8.20960581e-01,  8.19941163e-01, ...,
           8.17932487e-01,  8.17154706e-01,  8.16219032e-01],
         [-1.14909504e-02, -7.42466655e-03, -3.30242980e-03, ...,
           6.01573944e-01,  6.10398531e-01,  6.19079053e-01]],

        ...,

        [[-2.20628113e-01, -2.23500967e-01, -2.25722209e-01, ...,
          -1.67395502e-01, -1.58353239e-01, -1.50386378e-01],
         [ 1.05493081e+00,  1.05443764e+00,  1.05474639e+00, ...,
           1.13161087e+00,  1.12797856e+00,  1.12402344e+00],
         [-4.44135815e-02, -3.84659730e-02, -3.16540971e-02, ...,
           8.24253201e-01,  8.34537327e-01,  8.43444705e-01]],

        [[ 2.92082369e-01,  2.94431478e-01,  2.98463970e-01, ...,
           3.18692923e-01,  3.24238181e-01,  3.29352677e-01],
         [ 8.69390011e-01,  8.69472921e-01,  8.69239807e-01, ...,
           1.09961760e+00,  1.09582067e+00,  1.09134543e+00],
         [ 5.77146709e-02,  6.70738444e-02,  7.93055519e-02, ...,
           9.38503265e-01,  9.47341323e-01,  9.55910206e-01]],

        [[-3.00316632e-01, -3.05910826e-01, -3.11715782e-01, ...,
          -1.14041954e-01, -1.02489509e-01, -9.20391828e-02],
         [ 8.21832001e-01,  8.22073698e-01,  8.23741674e-01, ...,
           1.05188227e+00,  1.04469633e+00,  1.03700256e+00],
         [ 4.29062769e-02,  5.08630946e-02,  5.94698042e-02, ...,
           1.06792665e+00,  1.07625794e+00,  1.08358729e+00]]],


       [[[ 0.00000000e+00,  2.42193695e-03,  5.09894313e-03, ...,
           9.09861252e-02,  9.18623060e-02,  9.27472413e-02],
         [ 9.69678342e-01,  9.75120604e-01,  9.77426171e-01, ...,
           9.73648787e-01,  9.73459065e-01,  9.73240376e-01],
         [ 0.00000000e+00,  2.10079458e-02,  4.07000929e-02, ...,
           1.03705716e+00,  1.03851998e+00,  1.04011214e+00]],

        [[ 5.41930683e-02,  5.60734123e-02,  5.81541918e-02, ...,
           1.47804439e-01,  1.48697391e-01,  1.49710342e-01],
         [ 8.48918736e-01,  8.53796721e-01,  8.55770111e-01, ...,
           8.63021672e-01,  8.62857163e-01,  8.62674534e-01],
         [-4.43500094e-03,  1.62393469e-02,  3.51153687e-02, ...,
           9.98340130e-01,  9.99824345e-01,  1.00149632e+00]],

        [[-6.25008792e-02, -6.09138235e-02, -5.88819496e-02, ...,
           2.23487243e-02,  2.32608840e-02,  2.42199898e-02],
         [ 8.42434704e-01,  8.48799169e-01,  8.51793826e-01, ...,
           8.50406706e-01,  8.50193143e-01,  8.49926770e-01],
         [-1.72611810e-02,  4.78518754e-03,  2.53752656e-02, ...,
           1.00569499e+00,  1.00690293e+00,  1.00818288e+00]],

        ...,

        [[-2.17303038e-01, -2.20877424e-01, -2.24588811e-01, ...,
          -1.02807380e-01, -1.02175310e-01, -1.02075279e-01],
         [ 1.15032053e+00,  1.16098785e+00,  1.16755974e+00, ...,
           1.04584110e+00,  1.04529595e+00,  1.04489505e+00],
         [-1.97679307e-02,  3.17545235e-03,  2.65420638e-02, ...,
           1.36158717e+00,  1.36340880e+00,  1.36487401e+00]],

        [[ 2.41876945e-01,  2.31014282e-01,  2.15778545e-01, ...,
           3.33678395e-01,  3.34153235e-01,  3.34101617e-01],
         [ 8.60424757e-01,  8.79281461e-01,  9.00060058e-01, ...,
           9.48667049e-01,  9.48654711e-01,  9.49077249e-01],
         [ 1.80547863e-01,  2.44013414e-01,  3.06093395e-01, ...,
           1.40646398e+00,  1.40883064e+00,  1.41221523e+00]],

        [[-3.28947186e-01, -3.44300836e-01, -3.59433144e-01, ...,
          -5.94693348e-02, -5.95761985e-02, -5.96984476e-02],
         [ 9.48598504e-01,  9.74688232e-01,  9.96465802e-01, ...,
           9.11210537e-01,  9.10721719e-01,  9.10955071e-01],
         [ 1.09543413e-01,  1.47344708e-01,  1.81637138e-01, ...,
           1.58263636e+00,  1.58457553e+00,  1.58625627e+00]]],


       [[[ 0.00000000e+00, -3.13731260e-04, -7.89737154e-04, ...,
          -1.41580164e-01, -1.42017037e-01, -1.42441109e-01],
         [ 9.74520683e-01,  9.74882841e-01,  9.75454509e-01, ...,
           9.13187742e-01,  9.12859380e-01,  9.11898911e-01],
         [ 0.00000000e+00, -8.00634036e-04, -2.01892294e-03, ...,
           5.00182986e-01,  5.01572669e-01,  5.02980769e-01]],

        [[ 5.72645999e-02,  5.68254925e-02,  5.63361384e-02, ...,
          -8.32566321e-02, -8.37362856e-02, -8.41341466e-02],
         [ 8.56358588e-01,  8.56758714e-01,  8.57451022e-01, ...,
           7.98389077e-01,  7.98091650e-01,  7.97164261e-01],
         [-1.16833188e-02, -1.27542624e-02, -1.43027790e-02, ...,
           4.75960046e-01,  4.77127671e-01,  4.78313923e-01]],

        [[-6.25379086e-02, -6.27380013e-02, -6.30299300e-02, ...,
          -2.02942595e-01, -2.03346491e-01, -2.03697920e-01],
         [ 8.46689582e-01,  8.47185135e-01,  8.47793818e-01, ...,
           7.87605941e-01,  7.87190974e-01,  7.86136806e-01],
         [-1.60415657e-02, -1.67257469e-02, -1.81508716e-02, ...,
           4.82499808e-01,  4.83950227e-01,  4.85381275e-01]],

        ...,

        [[-2.51363933e-01, -2.50567734e-01, -2.50664294e-01, ...,
          -3.89232129e-01, -3.89436722e-01, -3.89753819e-01],
         [ 1.09516764e+00,  1.09460890e+00,  1.09331465e+00, ...,
           9.36930537e-01,  9.36616659e-01,  9.35517073e-01],
         [-1.33895064e-02, -1.77621394e-02, -2.38986034e-02, ...,
           8.01327109e-01,  8.02681446e-01,  8.03783655e-01]],

        [[ 3.08345109e-01,  3.06183219e-01,  3.03645134e-01, ...,
           1.99581623e-01,  1.99398726e-01,  1.98907956e-01],
         [ 1.06077635e+00,  1.06301141e+00,  1.06278265e+00, ...,
           7.04766512e-01,  7.05328465e-01,  7.04832911e-01],
         [ 3.35419685e-01,  3.39406341e-01,  3.43468368e-01, ...,
           8.16992402e-01,  8.19405913e-01,  8.23145032e-01]],

        [[-3.48109484e-01, -3.47106278e-01, -3.46241534e-01, ...,
          -4.24361497e-01, -4.24858183e-01, -4.25408721e-01],
         [ 9.94770706e-01,  9.92363334e-01,  9.86825883e-01, ...,
           9.23993468e-01,  9.21954751e-01,  9.18652356e-01],
         [ 2.23184496e-01,  2.19174221e-01,  2.11503133e-01, ...,
           1.06492472e+00,  1.06594729e+00,  1.06664896e+00]]]],
      dtype=float32), 
    'text': [
        'the person walked forward and is picking up his toolbox.', 
        'the person walked forward and is picking up his toolbox.', 
        'the person walked forward and is picking up his toolbox.'], 
    'lengths': array([120, 120, 120]), 
    'num_samples': 1, 
    'num_repetitions': 3
}

【おまけ】 results.npyをJSONに変換

results.npyをJSONに変換する手順は、次のとおりです。

(1) results.npyを読み込んで、ndarrayをlistに変換。

# results.npyの読み込み
import numpy as np
data = np.load("results.npy", allow_pickle=True)
dataDict = dict(enumerate(data.flatten()))[0]
dataDict["motion"] = dataDict["motion"].tolist()
dataDict["lengths"] = dataDict["lengths"].tolist()
print(dataDict)

(2) サイズの確認。

# サイズの確認
print("repetitions:", len(dataDict["motion"])) # 繰り返し
print("joint:", len(dataDict["motion"][0]))  # 関節
print("xyz:", len(dataDict["motion"][0][0])) # XYZ
print("frame:", len(dataDict["motion"][0][0][0])) # フレーム
repetitions: 3
joint: 22
xyz: 3
frame: 120

(3) results.jsonに出力。

# results.jsonに出力
import json
with open("results.json", mode='w') as f:
    json.dump(dataDict, f)

参考

次回



この記事が気に入ったらサポートをしてみませんか?