初心者のRNN(LSTM) | Kerasで試してみるを参考にLSTMを使った未来予測をやってみた。
PyCharmではそのまま実行できないのでライブラリ読み込みを
from tensorflow.python.keras.models import Sequential
Adamに関しては
from tensorflow.python.keras.optimizers import adam_v2
に変更した。これに伴い
optimizer = adam_v2.Adam(lr=0.001)
も変更を加えている。以下ソース
from tensorflow.python.keras.models import Sequential
from tensorflow.python.keras.layers import Dense, Activation
from tensorflow.python.keras.layers import LSTM
from tensorflow.python.keras.optimizers import adam_v2
from tensorflow.python.keras.callbacks import EarlyStopping
import numpy as np
import matplotlib.pyplot as plt
def myfunc(t, Fm, f, phi):
return Fm * np.sin(2 * np.pi * f * t + phi) + np.sin(5 * np.pi * f * t + phi)
# sin波にノイズを付与する
def toy_problem(T=100, ampl=0.05):
t = np.arange(0, 2 * T + 1)
noise = ampl * np.random.uniform(low=-1.0, high=1.0, size=len(t))
Fm = 1
f = 0.1
phi = 0
return myfunc(t, Fm, f, phi) + noise
f = toy_problem()
def make_dataset(low_data, n_prev=100):
data, target = [], []
maxlen = 25
for i in range(len(low_data) - maxlen):
data.append(low_data[i:i + maxlen])
target.append(low_data[i + maxlen])
re_data = np.array(data).reshape(len(data), maxlen, 1)
re_target = np.array(target).reshape(len(data), 1)
return re_data, re_target
# g -> 学習データ,h -> 学習ラベル
g, h = make_dataset(f)
# モデル構築
# 1つの学習データのStep数(今回は25)
length_of_sequence = g.shape[1]
in_out_neurons = 1
n_hidden = 300
model = Sequential()
model.add(LSTM(n_hidden, batch_input_shape=(None, length_of_sequence, in_out_neurons), return_sequences=False))
model.add(Dense(in_out_neurons))
model.add(Activation("linear"))
optimizer = adam_v2.Adam(lr=0.001)
model.compile(loss="mean_squared_error", optimizer=optimizer)
early_stopping = EarlyStopping(monitor='val_loss', mode='auto', patience=20)
model.fit(g, h,
batch_size=300,
epochs=100,
validation_split=0.1,
callbacks=[early_stopping]
)
# 予測
predicted = model.predict(g)
plt.figure()
plt.plot(range(25, len(predicted) + 25), predicted, color="r", label="predict_data")
plt.plot(range(0, len(f)), f, color="b", label="row_data")
plt.legend()
plt.show()
結果。
コメント