|
|
@ -0,0 +1,382 @@
|
|
|
|
|
|
|
|
import os
|
|
|
|
|
|
|
|
import pandas as pd
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
|
|
|
|
from sqlalchemy import create_engine, text
|
|
|
|
|
|
|
|
from sklearn.preprocessing import MinMaxScaler
|
|
|
|
|
|
|
|
import logging
|
|
|
|
|
|
|
|
from sklearn.model_selection import train_test_split
|
|
|
|
|
|
|
|
from keras.callbacks import EarlyStopping, ModelCheckpoint
|
|
|
|
|
|
|
|
from sklearn.metrics import mean_squared_error
|
|
|
|
|
|
|
|
from keras.models import Sequential
|
|
|
|
|
|
|
|
from keras.layers import LSTM, Dense
|
|
|
|
|
|
|
|
from tensorflow.keras.layers import Input
|
|
|
|
|
|
|
|
from keras.preprocessing.sequence import pad_sequences
|
|
|
|
|
|
|
|
import tensorflow as tf
|
|
|
|
|
|
|
|
from tensorflow.keras.models import Sequential
|
|
|
|
|
|
|
|
from tensorflow.keras.layers import LSTM, Dense
|
|
|
|
|
|
|
|
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
|
|
|
|
|
|
|
|
from tensorflow.keras.models import Model # 导入Model类
|
|
|
|
|
|
|
|
import matplotlib.pyplot as plt
|
|
|
|
|
|
|
|
import joblib
|
|
|
|
|
|
|
|
import pymysql
|
|
|
|
|
|
|
|
import logging
|
|
|
|
|
|
|
|
from datetime import datetime
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
logging.basicConfig(level=logging.ERROR, format='%(asctime)s - %(levelname)s - %(message)s')
|
|
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
logger.setLevel(logging.INFO)
|
|
|
|
|
|
|
|
def get_db_config():
|
|
|
|
|
|
|
|
#从环境变量中获取数据库配置。
|
|
|
|
|
|
|
|
config = {
|
|
|
|
|
|
|
|
"host": os.getenv('DB_HOST', '127.0.0.1'),
|
|
|
|
|
|
|
|
"user": os.getenv('DB_USER', 'root'),
|
|
|
|
|
|
|
|
"password": os.getenv('DB_PASSWORD','mysql>hyx123'),
|
|
|
|
|
|
|
|
"db": os.getenv('DB_NAME', 'airquility'),
|
|
|
|
|
|
|
|
"charset": 'utf8',
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 检查环境变量是否已设置
|
|
|
|
|
|
|
|
for key, value in config.items():
|
|
|
|
|
|
|
|
if value is None:
|
|
|
|
|
|
|
|
raise ValueError(f"缺少环境变量: {key}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return config
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def create_database_engine(config):
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
创建数据库引擎。
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
db_string = f'mysql+pymysql://{config["user"]}:{config["password"]}@{config["host"]}/{config["db"]}?charset={config["charset"]}'
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
|
|
|
engine = create_engine(db_string)
|
|
|
|
|
|
|
|
print(type(engine))
|
|
|
|
|
|
|
|
except Exception as e:
|
|
|
|
|
|
|
|
print(f"创建数据库引擎失败: {e}")
|
|
|
|
|
|
|
|
# 根据需要处理异常,例如记录日志或重试
|
|
|
|
|
|
|
|
raise # 如果需要将异常继续抛出
|
|
|
|
|
|
|
|
return engine
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def fetch_data(engine, query):
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
从数据库中获取数据。
|
|
|
|
|
|
|
|
参数:
|
|
|
|
|
|
|
|
engine: 数据库连接引擎对象。
|
|
|
|
|
|
|
|
query: SQL查询字符串。
|
|
|
|
|
|
|
|
返回:
|
|
|
|
|
|
|
|
查询结果的数据框(df)。
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
# 确保query是字符串类型
|
|
|
|
|
|
|
|
if not isinstance(query, str):
|
|
|
|
|
|
|
|
logging.error("查询字符串类型错误,query应为字符串。")
|
|
|
|
|
|
|
|
raise ValueError("查询字符串类型错误,query应为字符串。")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if not query.strip():
|
|
|
|
|
|
|
|
logging.error("查询字符串为空。")
|
|
|
|
|
|
|
|
raise ValueError("查询字符串为空。")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
|
|
|
df = pd.read_sql(text(query), engine)
|
|
|
|
|
|
|
|
return df
|
|
|
|
|
|
|
|
except Exception as e:
|
|
|
|
|
|
|
|
logging.error(f"执行SQL查询失败: {e}")
|
|
|
|
|
|
|
|
raise # 重新抛出异常以便上层处理
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def preprocess_data(df, target_col, default_year=2024, features=None):
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
对数据进行预处理,包括日期列转换、特征标准化等。
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
# 检查df是否为空
|
|
|
|
|
|
|
|
if df.empty:
|
|
|
|
|
|
|
|
logging.error("输入的DataFrame为空")
|
|
|
|
|
|
|
|
return None, None, None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 检查'ptime'列是否存在
|
|
|
|
|
|
|
|
if 'ptime' not in df.columns:
|
|
|
|
|
|
|
|
logging.error("DataFrame中不存在'ptime'列")
|
|
|
|
|
|
|
|
return None, None, None
|
|
|
|
|
|
|
|
default_year = 2024
|
|
|
|
|
|
|
|
df['ptime'] = df['ptime'].apply(lambda x: datetime.strptime(f"{default_year}/{x}", "%Y/%m/%d"))
|
|
|
|
|
|
|
|
# 或者,如果使用pd.to_datetime,并且'ptime'格式特殊,需要指定格式
|
|
|
|
|
|
|
|
# df['ptime'] = pd.to_datetime(df['ptime'], errors='coerce', format='%m/%d').dt.strftime('%Y-%m-%d')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
print(df.head)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 如果'ptime'已经是datetime类型,则无需转换
|
|
|
|
|
|
|
|
if df['ptime'].dtype == 'datetime64[ns]':
|
|
|
|
|
|
|
|
print("ptime列已经是以日期时间格式存储。")
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
|
|
|
# 尝试将'ptime'列转换为datetime类型
|
|
|
|
|
|
|
|
df['ptime'] = pd.to_datetime(df['ptime'], format='%m/%d/%Y')
|
|
|
|
|
|
|
|
except ValueError:
|
|
|
|
|
|
|
|
logging.error("ptime列转换为datetime类型失败,可能是因为格式不正确。")
|
|
|
|
|
|
|
|
return None, None, None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 设置'ptime'为索引
|
|
|
|
|
|
|
|
#df.set_index('ptime', inplace=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 确定features列表
|
|
|
|
|
|
|
|
if target_col in df.columns:
|
|
|
|
|
|
|
|
features = df.columns.drop(target_col)
|
|
|
|
|
|
|
|
print("features:", features)
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
logging.warning(f"目标列 '{target_col}' 在DataFrame中未找到,将不进行列删除操作。")
|
|
|
|
|
|
|
|
features = df.columns
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 检查features是否被正确设置
|
|
|
|
|
|
|
|
if features is None:
|
|
|
|
|
|
|
|
logging.error("未找到任何特征列。")
|
|
|
|
|
|
|
|
return None, None, None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
print("@@@@")
|
|
|
|
|
|
|
|
print(target_col)
|
|
|
|
|
|
|
|
print("@@@@")
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
|
|
|
df.set_index('ptime', inplace=True)
|
|
|
|
|
|
|
|
except KeyError:
|
|
|
|
|
|
|
|
print("列 'ptime' 不存在,无法设置为索引。")
|
|
|
|
|
|
|
|
# 在这里处理缺少'ptime'的情况,比如跳过相关操作或使用其他列
|
|
|
|
|
|
|
|
# 使用MinMaxScaler进行特征缩放
|
|
|
|
|
|
|
|
scaler = MinMaxScaler()
|
|
|
|
|
|
|
|
scaled_features = scaler.fit_transform(df[features])
|
|
|
|
|
|
|
|
scaled_target = scaler.fit_transform(df.index.values.reshape(-1, 1))
|
|
|
|
|
|
|
|
print("~~~")
|
|
|
|
|
|
|
|
return scaled_features, scaled_target, scaler
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def split_dataset_into_train_test(features, target, test_size=0.2):
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
切分数据集为训练集和测试集。
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
# 检查features和target的类型以及长度是否相等
|
|
|
|
|
|
|
|
if not isinstance(features, np.ndarray) or not isinstance(target, np.ndarray):
|
|
|
|
|
|
|
|
raise TypeError("features and target must be numpy arrays")
|
|
|
|
|
|
|
|
if len(features) != len(target):
|
|
|
|
|
|
|
|
raise ValueError("features and target must have the same length")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 检查test_size是否在合理的范围内
|
|
|
|
|
|
|
|
if not 0 < test_size < 1:
|
|
|
|
|
|
|
|
raise ValueError("test_size must be between 0 and 1")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 计算训练集大小
|
|
|
|
|
|
|
|
train_size = int(len(features) * (1 - test_size))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 使用numpy的切片操作,这不会创建新的数据副本,提高性能
|
|
|
|
|
|
|
|
train_features, test_features = features[:train_size], features[train_size:]
|
|
|
|
|
|
|
|
train_target, test_target = target[:train_size], target[train_size:]
|
|
|
|
|
|
|
|
print("123456")
|
|
|
|
|
|
|
|
print(features)
|
|
|
|
|
|
|
|
print(target)
|
|
|
|
|
|
|
|
print(train_features)
|
|
|
|
|
|
|
|
print(train_target)
|
|
|
|
|
|
|
|
print(test_features)
|
|
|
|
|
|
|
|
print(test_target)
|
|
|
|
|
|
|
|
return train_features, test_features, train_target, test_target
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def validate_data_shapes(train_features, test_features, n_steps):
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
验证训练和测试数据形状是否符合预期。
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
if train_features.shape[1] != n_steps or test_features.shape[1] != n_steps:
|
|
|
|
|
|
|
|
raise ValueError(f"训练和测试特征的第二维度(时间步长)应为{n_steps}")
|
|
|
|
|
|
|
|
print("7890")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def build_model(n_steps, lstm_units, dense_units, input_shape):
|
|
|
|
|
|
|
|
inputs = Input(shape=input_shape) # 添加Input对象
|
|
|
|
|
|
|
|
x = LSTM(lstm_units)(inputs) # 直接将Input对象传递给LSTM层
|
|
|
|
|
|
|
|
outputs = Dense(dense_units)(x)
|
|
|
|
|
|
|
|
model = tf.keras.Model(inputs=inputs, outputs=outputs)
|
|
|
|
|
|
|
|
model.compile(optimizer='adam', loss='mse')
|
|
|
|
|
|
|
|
print("!!!")
|
|
|
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def validate_params(epochs, batch_size):
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
确保 epochs 和 batch_size 是合法的参数。
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
if not isinstance(epochs, int) or epochs <= 0:
|
|
|
|
|
|
|
|
raise ValueError("epochs 应该是一个正整数")
|
|
|
|
|
|
|
|
if not isinstance(batch_size, int) or batch_size <= 0:
|
|
|
|
|
|
|
|
raise ValueError("batch_size 应该是一个正整数")
|
|
|
|
|
|
|
|
if epochs <= 0 or batch_size <= 0:
|
|
|
|
|
|
|
|
raise ValueError("epochs和batch_size必须大于0")
|
|
|
|
|
|
|
|
print("%%%")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def ensure_directory_safety(path:str):
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
确保路径安全且存在。
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
if not os.path.isabs(path):
|
|
|
|
|
|
|
|
raise ValueError("路径应该是绝对路径")
|
|
|
|
|
|
|
|
directory = os.path.dirname(path)
|
|
|
|
|
|
|
|
print(directory)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
|
|
|
# 检查目录是否需要创建
|
|
|
|
|
|
|
|
if not os.path.exists(directory):
|
|
|
|
|
|
|
|
# 添加日志记录
|
|
|
|
|
|
|
|
logger.info(f"目录 {directory} 不存在,开始创建。")
|
|
|
|
|
|
|
|
# 使用 exist_ok=True 避免在目录已存在时抛出异常
|
|
|
|
|
|
|
|
os.makedirs(directory, exist_ok=True)
|
|
|
|
|
|
|
|
logger.info(f"目录 {directory} 创建成功。")
|
|
|
|
|
|
|
|
except PermissionError:
|
|
|
|
|
|
|
|
# 捕获权限异常,给出清晰的错误提示
|
|
|
|
|
|
|
|
logger.error(f"没有权限在 {directory} 创建目录。")
|
|
|
|
|
|
|
|
raise
|
|
|
|
|
|
|
|
except Exception as e:
|
|
|
|
|
|
|
|
# 捕获其他异常,记录并抛出
|
|
|
|
|
|
|
|
logger.error(f"创建目录 {directory} 时发生未知错误:{e}")
|
|
|
|
|
|
|
|
raise
|
|
|
|
|
|
|
|
print("===")
|
|
|
|
|
|
|
|
def train_model(model: Model, train_features, train_target, test_features, test_target, epochs: int, batch_size: int,
|
|
|
|
|
|
|
|
patience: int, save_best_only: bool = True, monitor: str = 'val_loss', mode: str = 'min',
|
|
|
|
|
|
|
|
model_path: str = "best_model.h5") -> dict:
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
训练模型,并根据早停策略和性能指标保存最佳模型。
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
:param model: Keras模型实例
|
|
|
|
|
|
|
|
:param train_features: 训练特征
|
|
|
|
|
|
|
|
:param train_target: 训练目标
|
|
|
|
|
|
|
|
:param test_features: 测试特征
|
|
|
|
|
|
|
|
:param test_target: 测试目标
|
|
|
|
|
|
|
|
:param epochs: 训练轮数
|
|
|
|
|
|
|
|
:param batch_size: 批量大小
|
|
|
|
|
|
|
|
:param patience: 早停策略的耐心值
|
|
|
|
|
|
|
|
:param save_best_only: 是否只保存最佳模型
|
|
|
|
|
|
|
|
:param monitor: 监控的指标
|
|
|
|
|
|
|
|
:param mode: 监控指标的模式(min/max)
|
|
|
|
|
|
|
|
:param model_path: 模型保存路径
|
|
|
|
|
|
|
|
:return: 训练历史记录
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
logging.basicConfig(level=logging.INFO)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
model_path = "/path/to/your/model.h5"
|
|
|
|
|
|
|
|
validate_params(epochs, batch_size)
|
|
|
|
|
|
|
|
ensure_directory_safety(model_path)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 使用ModelCheckpoint保存最佳模型
|
|
|
|
|
|
|
|
filepath = model_path
|
|
|
|
|
|
|
|
checkpoint = ModelCheckpoint(filepath, monitor=monitor, verbose=1, save_best_only=save_best_only, mode=mode)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 定义早停策略
|
|
|
|
|
|
|
|
early_stopping = EarlyStopping(monitor=monitor, patience=patience, verbose=1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
|
|
|
history = model.fit(train_features, train_target, epochs=epochs, batch_size=batch_size,
|
|
|
|
|
|
|
|
validation_data=(test_features, test_target), verbose=1,
|
|
|
|
|
|
|
|
callbacks=[early_stopping, checkpoint])
|
|
|
|
|
|
|
|
logging.info("###")
|
|
|
|
|
|
|
|
return history
|
|
|
|
|
|
|
|
except ValueError as ve:
|
|
|
|
|
|
|
|
logging.error(f"参数错误: {ve}")
|
|
|
|
|
|
|
|
raise
|
|
|
|
|
|
|
|
except OSError as oe:
|
|
|
|
|
|
|
|
logging.error(f"文件操作错误: {oe}")
|
|
|
|
|
|
|
|
raise
|
|
|
|
|
|
|
|
except Exception as e:
|
|
|
|
|
|
|
|
logging.error(f"模型训练过程中发生异常: {e}")
|
|
|
|
|
|
|
|
raise
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def build_and_train_model(n_steps, features, target, train_features, train_target, test_features, test_target,lstm_units=50, dense_units=1, optimizer='adam', loss='mse', epochs=100, batch_size=32,patience=10, model_save_path='model.h5'):
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
构建LSTM模型并进行训练,增加了参数可配置性,早停策略和模型保存。
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
# 输入数据验证
|
|
|
|
|
|
|
|
if not (isinstance(train_features, np.ndarray) and isinstance(train_target, np.ndarray) and isinstance(test_features, np.ndarray) and isinstance(test_target, np.ndarray)):
|
|
|
|
|
|
|
|
raise ValueError("输入数据train_features, train_target, test_features, test_target必须是numpy数组")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
checkpoint = ModelCheckpoint(filepath="/path/to/your/model.keras", # 注意这里的路径保持为.h5
|
|
|
|
|
|
|
|
monitor='val_loss', # 或您希望监控的指标
|
|
|
|
|
|
|
|
verbose=1,
|
|
|
|
|
|
|
|
save_best_only=True,
|
|
|
|
|
|
|
|
mode='min')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 数据形状验证
|
|
|
|
|
|
|
|
validate_data_shapes(train_features, test_features, n_steps)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
model = build_model(n_steps, lstm_units, dense_units, input_shape)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 早停策略
|
|
|
|
|
|
|
|
early_stopping = EarlyStopping(monitor='val_loss', patience=patience, verbose=1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
history = model.fit(train_features, train_target,
|
|
|
|
|
|
|
|
validation_data=(test_features, test_target),
|
|
|
|
|
|
|
|
epochs=epochs,
|
|
|
|
|
|
|
|
batch_size=batch_size,
|
|
|
|
|
|
|
|
callbacks=[checkpoint],
|
|
|
|
|
|
|
|
# 其他参数...
|
|
|
|
|
|
|
|
) # 模型保存
|
|
|
|
|
|
|
|
# 增加了路径验证来防止潜在的安全问题,这里简化处理,实际应用中可能需要更复杂的逻辑
|
|
|
|
|
|
|
|
if not model_save_path.endswith('.h5'):
|
|
|
|
|
|
|
|
model_save_path += '.h5'
|
|
|
|
|
|
|
|
model.save(model_save_path)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return model, history
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def evaluate_model(model, scaler, test_target, predictions):
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
评估模型性能并反向转换预测结果。
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
predictions = scaler.inverse_transform(predictions)
|
|
|
|
|
|
|
|
test_target_inv = scaler.inverse_transform(test_target.reshape(-1, 1))
|
|
|
|
|
|
|
|
mse = mean_squared_error(test_target_inv, predictions)
|
|
|
|
|
|
|
|
print(f'Mean Squared Error: {mse}')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return mse
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
engine = create_database_engine(get_db_config())
|
|
|
|
|
|
|
|
query = "SELECT ptime, ci FROM may"
|
|
|
|
|
|
|
|
df = fetch_data(engine, query)
|
|
|
|
|
|
|
|
target_col = 'ptime'
|
|
|
|
|
|
|
|
features, target, scaler = preprocess_data(df, target_col)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
train_features, test_features, train_target, test_target = split_dataset_into_train_test(features, target, test_size=0.2)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
n_steps = 5
|
|
|
|
|
|
|
|
# 假设train_features和test_features是你的数据,且它们是二维数组
|
|
|
|
|
|
|
|
# 首先,你需要获取或设定一个maxlen,这里假设我们已知或计算出它应该是5
|
|
|
|
|
|
|
|
maxlen = 5
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 对训练数据进行填充或截断
|
|
|
|
|
|
|
|
train_features_padded = pad_sequences(train_features, maxlen=maxlen, padding='post', truncating='post')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 对测试数据进行同样的处理
|
|
|
|
|
|
|
|
test_features_padded = pad_sequences(test_features, maxlen=maxlen, padding='post', truncating='post')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
input_shape = (n_steps, int(train_features.shape[1]))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
model, history = build_and_train_model(n_steps=n_steps,
|
|
|
|
|
|
|
|
features=features,
|
|
|
|
|
|
|
|
target=target,
|
|
|
|
|
|
|
|
train_target=train_target,
|
|
|
|
|
|
|
|
test_target=test_target,
|
|
|
|
|
|
|
|
train_features=train_features_padded,
|
|
|
|
|
|
|
|
test_features=test_features_padded)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
predictions = model.predict(test_features)
|
|
|
|
|
|
|
|
mse = evaluate_model(model, scaler, test_target, predictions)
|
|
|
|
|
|
|
|
# 可视化预测结果(可选)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#plt.plot(test_target, label='Actual')
|
|
|
|
|
|
|
|
plt.plot(predictions, label='Predicted')
|
|
|
|
|
|
|
|
plt.legend()
|
|
|
|
|
|
|
|
plt.xlabel('Ptime')
|
|
|
|
|
|
|
|
plt.ylabel('CI')
|
|
|
|
|
|
|
|
#plt.plot(ptime, ci)
|
|
|
|
|
|
|
|
plt.show()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 保存模型
|
|
|
|
|
|
|
|
model.save('trained_model.h5')
|
|
|
|
|
|
|
|
joblib.dump(scaler, 'scaler.joblib')
|