見出し画像

超簡単Pythonで株価予測(GluonTS 利用)時系列予測

PythonでGluonTSを利用して25日先までの株価予測を超簡単に時系列予測(Amazon製)

Facebook製の同様ツールについては過去の投稿をどうぞ

1. ツールインストール

$ pip install mxnet~=1.7 gluonts pandas-datareader scikit-learn

2. ファイル作成

pred.py

from gluonts.dataset.common import ListDataset
from gluonts.model.deepar import DeepAREstimator
from gluonts.mx.trainer import Trainer
from gluonts.dataset.util import to_pandas
import matplotlib.pyplot as plt
import pandas_datareader as pdr
from sklearn.model_selection import train_test_split

training, test = train_test_split(
 pdr.get_data_yahoo("AAPL", "2019-11-01", "2020-11-01")["Close"],
 test_size=0.2,
 shuffle=False,
)
training_data = ListDataset(
   [{"start": training.index[0], "target": training}],
   freq = "d"
)
estimator = DeepAREstimator(freq="d", prediction_length=25, trainer=Trainer(epochs=10))
predictor = estimator.train(training_data=training_data)
test_data = ListDataset(
   [{"start": test.index[0], "target": test}],
   freq = "d"
)
for test_entry, forecast in zip(test_data, predictor.predict(test_data)):
   to_pandas(test_entry)[-150:].plot(figsize=(12, 5), linewidth=2)
   forecast.plot(color='g')
plt.grid(which='both')
plt.legend(["observations", "median prediction", "90% confidence interval", "50% confidence interval"])
plt.savefig("pred.png")

3. 実行

$ python pred.py

ダウンロード (1)

以上、超簡単!

4. 参考


この記事が気に入ったらサポートをしてみませんか?