2024年2月に発表された時系列予測のためのLag-Llamaモデル


Lag-Llamaとは?

Lag-Llamaは最近発表された時系列予測のためのオープンソースモデルです。
様々なドメインにわたる広範囲の時系列データでトレーニングされています。

公式のGithubレポジトリにあるデモを参考にLag-Llamaでゼロショット予測を試してみます。環境はGoogle Corabです。

インストール

Lag-LlamaのGithubレポジトリから取得してインストールを行います。

!git clone https://github.com/time-series-foundation-models/lag-llama/
cd /content/lag-llama
!pip install -r requirements.txt --quiet 

事前学習済みのモデルのウェイトを取得

HuggingFaceから事前学習済みのモデルのウェイトをダウンロードしています。

!huggingface-cli download time-series-foundation-models/Lag-Llama lag-llama.ckpt --local-dir /content/lag-llama

ライブラリの読み込み

from itertools import islice

from matplotlib import pyplot as plt
import matplotlib.dates as mdates

import torch
from gluonts.evaluation import make_evaluation_predictions, Evaluator
from gluonts.dataset.repository.datasets import get_dataset

from gluonts.dataset.pandas import PandasDataset
import pandas as pd

from lag_llama.gluon.estimator import LagLlamaEstimator

Lag-Llamaで予測を行うための関数

def get_lag_llama_predictions(dataset, prediction_length, num_samples=100):
    ckpt = torch.load("lag-llama.ckpt", map_location=torch.device('cuda:0')) # GPUを利用しています。
    estimator_args = ckpt["hyper_parameters"]["model_kwargs"]

    estimator = LagLlamaEstimator(
        ckpt_path="lag-llama.ckpt",
        prediction_length=prediction_length,
        context_length=32, # 事前学習済みのモデルが訓練された設定であるため、変更してはいけない。

        # estimator args
        input_size=estimator_args["input_size"],
        n_layer=estimator_args["n_layer"],
        n_embd_per_head=estimator_args["n_embd_per_head"],
        n_head=estimator_args["n_head"],
        scaling=estimator_args["scaling"],
        time_feat=estimator_args["time_feat"],

        batch_size=1,
        num_parallel_samples=100
    )

    lightning_module = estimator.create_lightning_module()
    transformation = estimator.create_transformation()
    predictor = estimator.create_predictor(transformation, lightning_module)

    forecast_it, ts_it = make_evaluation_predictions(
        dataset=dataset,
        predictor=predictor,
        num_samples=num_samples
    )
    forecasts = list(forecast_it)
    tss = list(ts_it)

    return forecasts, tss

予測を行うための関数を見てみます。大まかに以下の処理を行なっています。

  1. 取得した事前学習済みのモデルをロード

  2. ハイパーパラメーターを取得

  3. LagLlamaEstimatorで推定器を作成

  4. 予測器を作成(create_predictor)

  5. 予測(make_evaluation_predictions)

  6. 最終的に予測された時系列データを返します

1で先ほど取得した事前学習済みのモデルのウェイトのチェックポイントファイルであるlag-llama.ckptをロードしています。cuda:0を指定しているため、GPU環境を想定しています。Google Corabの場合は、ランタイムのタイプを変更することでGPU環境に変更できます。

データセットの読み込み

import pandas as pd
from gluonts.dataset.pandas import PandasDataset

url = (
    "https://gist.githubusercontent.com/rsnirwan/a8b424085c9f44ef2598da74ce43e7a3"
    "/raw/b6fdef21fe1f654787fa0493846c546b7f9c4df2/ts_long.csv"
)
df = pd.read_csv(url, index_col=0, parse_dates=True)
df
for col in df.columns:
    if df[col].dtype != 'object' and pd.api.types.is_string_dtype(df[col]) == False:
        df[col] = df[col].astype('float32')

dataset = PandasDataset.from_long_dataframe(df, target="target", item_id="item_id")

backtest_dataset = dataset
prediction_length = 24  # 予測期間を定義します。ここではデータが1時間毎の頻度であるため、24を使用します。
num_samples = 100 # 確率分布からサンプリングされるサンプルの数です。

予測を行う

先ほど定義した関数で予測を行います。

forecasts, tss = get_lag_llama_predictions(backtest_dataset, prediction_length, num_samples)

予測の可視化

plt.figure(figsize=(20, 15))
date_formater = mdates.DateFormatter('%b, %d')
plt.rcParams.update({'font.size': 15})

for idx, (forecast, ts) in islice(enumerate(zip(forecasts, tss)), 9):
    ax = plt.subplot(3, 3, idx+1)

    plt.plot(ts[-4 * prediction_length:].to_timestamp(), label="target", )
    forecast.plot( color='g')
    plt.xticks(rotation=60)
    ax.xaxis.set_major_formatter(date_formater)
    ax.set_title(forecast.item_id)

plt.gcf().tight_layout()
plt.legend()
plt.show()