Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
# NOTE(review): the enclosing `def` line (presumably
# `def plot_results(predicted_data, true_data):`) is missing from this
# extract — confirm the signature against the full file.
# Plot the true series and the predicted series on a single white figure.
fig = plt.figure(facecolor='white')
ax = fig.add_subplot(111)
ax.plot(true_data, label='True Data')
plt.plot(predicted_data, label='Prediction')
plt.legend()
plt.show()
# predicted_data每个元素的长度必须为prediction_len fig = plt.figure(facecolor='white') ax = fig.add_subplot(111) ax.plot(true_data, label='True Data') # Pad the list of predictions to shift it in the graph to it's correct start for i, data in enumerate(predicted_data): # data为一维数组,长度为prediction_len。predicted_data:二维数组,每个元素为list padding = [None for p in range(i * prediction_len)] plt.plot(padding + data, label='Prediction') # padding + data:list拼接操作 plt.legend() plt.show()
# Used only to train the model, but can also evaluate the model depending
# on the arguments passed in.
# NOTE(review): the enclosing `def train_model(...)` line is missing from
# this extract, so the exact signature cannot be confirmed here.
'''Train and save the model, and evaluate it against test data (by plotting).'''
# Create the directory the trained model is saved into.
os.makedirs(os.path.join(get_parent_dir(),configs['model']['save_dir']))
# NOTE(review): the call head (e.g. `data = DataLoader(`) for the argument
# list below is missing from this extract — confirm against the full file.
os.path.join(get_parent_dir(),os.path.join('data', stock_code + ".csv")),  # configs['data']['filename']
split,
configs['data']['columns']  # train only on the data of the selected columns
)
'''
# in-memory training
model.train(
    x,
    y,
    epochs = configs['training']['epochs'],
    batch_size = configs['training']['batch_size'],
    save_dir = configs['model']['save_dir']
)
'''
# Train the model:
# out-of memory generative training
# NOTE(review): the statement heads (presumably `steps_per_epoch = math.ceil(`
# and `model.train_generator(`) appear to have been dropped by the
# extraction; the fragments below are their argument lists — confirm
# against the full file.
(data.len_train - configs['data']['sequence_length']) / configs['training']['batch_size'])
data_gen=data.generate_train_batch(
    seq_len=configs['data']['sequence_length'],
    batch_size=configs['training']['batch_size'],
    normalise=configs['data']['normalise']
),
epochs=configs['training']['epochs'],
batch_size=configs['training']['batch_size'],
steps_per_epoch=steps_per_epoch,
save_dir=os.path.join(get_parent_dir(),configs['model']['save_dir']),  # directory the model is saved into
save_name=stock_code  # saved file is keyed by the stock code
)
# Predict
x_test, y_test = data.get_test_data(
    seq_len=configs['data']['sequence_length'],
    normalise=configs['data']['normalise']
)
predictions = model.predict_sequences_multiple(x_test, configs['data']['sequence_length'], configs['data']['sequence_length'])
print("训练:\n", predictions)
# plot_results_multiple(predictions, y_test, configs['data']['sequence_length'])
# Run prediction for the specified company's stock.
# NOTE(review): the enclosing `def prediction(...)` line is missing from
# this extract — the dead code later in the file calls it as
# `prediction(stock_code=..., real=..., pre_len=...)`; confirm the signature.
'''Use the saved model to run predictions on the input data.'''
# NOTE(review): the call head (e.g. `data = DataLoader(`) for the argument
# list below is missing from this extract.
os.path.join(get_data_path(), stock_code + ".csv"),  # configs['data']['filename']
configs['data']['train_test_split'],
configs['data']['columns']
)
# predict_length = configs['data']['sequence_length']  # prediction length
# NOTE(review): the `if` branch matching this `else` is missing from this
# extract — confirm against the full file.
else:
    # Use one window of data at the specified position, so there is real
    # data to compare against (used for plotting the comparison).
    win_position = -configs['data']['sequence_length']
# NOTE(review): the two argument lists below appear to belong to
# `data.get_test_data(` calls whose heads were dropped by the extraction.
seq_len=configs['data']['sequence_length'],
normalise=False
)
# Ground-truth slice aligned with the prediction window.
y_test_real = y_test[win_position:win_position + predict_length]
seq_len=configs['data']['sequence_length'],
normalise=configs['data']['normalise']
)
# predictions = model.predict_sequences_multiple(x_test, configs['data']['sequence_length'],
#                                                predict_length)
# De-normalise
# print("预测数据:\n", predictions)
print("真实数据:\n", y_test_real)
# plot_results_multiple(predictions, y_test, predict_length)
if real:
    # No ground truth available — plot the prediction on its own.
    plot_results(predictions, [])
else:
    plot_results(predictions, y_test_real)
# NOTE(review): dead code below — a former `main` entry point commented out
# by wrapping it in a module-level string literal. The string content is
# left verbatim (it is never executed); consider deleting it outright.
'''
def main(stock_code, train=False, predict=False):
    configs = json.load(open(get_config_path(), 'r'))
    companies = configs['companies']

    if stock_code not in companies.keys():
        print("该公司不在指定范围内")
        return -1

    if train:
        train_model(stock_code)
        return 0  # 训练完成

    if predict:
        # for root, dirs, files in os.walk('saved_models'):
        # root:当前目录路径 dirs: 当前路径下所有子目录 files:当前路径下所有非目录子文件
        if stock_code + ".h5" in os.listdir("saved_models"):
            # os.listdir:获得当前目录下的所有文件名。不包括子目录
            return prediction(stock_code=stock_code, real=True, pre_len=20)
        else:
            return -2  # 该公司还没有训练模型
'''
# 2-D array: [[data,value],[...]]
# Commented-out script entry point:
# if __name__ == '__main__':
#     # get_all_last_data("2010-01-01")  # fetch the latest data first
#     train_all_stock()
#     # predict_all_stock()