Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

20

21

22

23

24

25

26

27

28

29

30

31

32

33

34

35

36

37

38

39

40

41

42

43

44

45

46

47

48

49

50

51

52

53

54

55

56

57

58

59

60

61

62

63

64

65

66

67

68

69

70

71

72

73

74

75

76

77

78

79

80

81

82

83

84

85

86

87

88

89

90

91

92

93

94

95

96

97

98

99

100

101

102

103

104

105

106

107

108

109

110

111

112

113

114

115

116

117

118

119

120

121

122

123

124

125

126

127

128

129

130

131

132

133

134

135

136

137

138

139

140

141

142

143

144

145

146

147

148

149

150

151

152

153

154

155

156

157

158

159

160

161

162

163

164

165

166

167

168

169

170

171

172

173

174

175

176

177

178

179

180

181

182

183

184

185

186

187

188

189

190

191

192

193

194

195

196

197

198

199

200

201

202

203

204

205

206

207

208

209

210

211

212

213

214

215

216

217

218

219

220

221

222

223

224

225

226

227

228

229

230

231

232

233

234

235

236

237

238

239

240

241

242

243

244

245

246

247

248

249

250

251

252

253

254

255

256

257

258

259

260

import keras 

 

__author__ = "Jakob Aungiers" 

__copyright__ = "Jakob Aungiers 2018" 

__version__ = "2.0.0" 

__license__ = "MIT" 

 

import os 

import json 

import numpy as np 

import math 

import pandas as pd 

import matplotlib.pyplot as plt 

from LSTMPredictStock.core.data_processor import DataLoader 

from LSTMPredictStock.core.model import Model 

from datetime import datetime,timedelta 

from LSTMPredictStock.core.get_domestic_hist_stock import get_all_last_data 

from LSTMPredictStock.core.get_domestic_hist_stock import get_single_last_data 

 

 

def plot_results(predicted_data, true_data):
    """Draw the true series and the predicted series on one white figure.

    Both arguments are handed to matplotlib unchanged; they are intended to
    be one-dimensional sequences of the same length. The call blocks in
    ``plt.show()`` until the window is closed (backend dependent).
    """
    figure = plt.figure(facecolor='white')
    axes = figure.add_subplot(111)
    axes.plot(true_data, label='True Data')
    # pyplot draws on the current axes, which is the subplot created above.
    plt.plot(predicted_data, label='Prediction')
    plt.legend()
    plt.show()

 

 

# predicted_data每个元素的长度必须为prediction_len 

def plot_results_multiple(predicted_data, true_data, prediction_len):
    """Plot several prediction windows over the true series.

    Args:
        predicted_data: sequence of prediction windows. Window ``i`` is drawn
            starting at x position ``i * prediction_len``; each window is
            expected to hold ``prediction_len`` values.
        true_data: series drawn with the label 'True Data'.
        prediction_len: horizontal offset between consecutive windows.
    """
    fig = plt.figure(facecolor='white')
    ax = fig.add_subplot(111)
    ax.plot(true_data, label='True Data')
    for i, data in enumerate(predicted_data):
        # Pad the window on the left so it is shifted to its correct start.
        padding = [None] * (i * prediction_len)
        # list(data) makes this a plain list concatenation even when the
        # window is a numpy array or tuple (list + ndarray would not
        # concatenate).
        plt.plot(padding + list(data), label='Prediction')
    plt.legend()
    plt.show()

 

 

# 只用于训练模型,但同时可根据参数进行模型的评估 

def train_model(stock_code, predict=False):
    """Train and save the model for one stock, optionally evaluating it.

    Args:
        stock_code: stock identifier; ``<data dir>/<stock_code>.csv`` is the
            training data and ``stock_code`` is passed as the save name.
        predict: when false (default) the train/test split is forced to 1 so
            all history is used for training. When true the split from
            config.json is used and predictions on the test part are printed.

    Returns:
        None.
    """
    # Use a context manager so the config file handle is always closed.
    with open(get_config_path(), 'r') as config_file:
        configs = json.load(config_file)

    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    save_dir = os.path.join(get_parent_dir(), configs['model']['save_dir'])
    os.makedirs(save_dir, exist_ok=True)

    # Without evaluation there is no need to hold back test data.
    split = configs['data']['train_test_split'] if predict else 1

    data = DataLoader(
        os.path.join(get_data_path(), stock_code + ".csv"),
        split,
        configs['data']['columns'],  # columns of the CSV used for training
    )

    model = Model()
    model.build_model(configs)  # build a fresh model from the configuration

    seq_len = configs['data']['sequence_length']
    batch_size = configs['training']['batch_size']

    # Generator-based ("out-of-memory") training: batches are produced by
    # DataLoader.generate_train_batch instead of passing full x/y arrays.
    steps_per_epoch = math.ceil((data.len_train - seq_len) / batch_size)
    model.train_generator(
        data_gen=data.generate_train_batch(
            seq_len=seq_len,
            batch_size=batch_size,
            normalise=configs['data']['normalise'],
        ),
        epochs=configs['training']['epochs'],
        batch_size=batch_size,
        steps_per_epoch=steps_per_epoch,
        save_dir=save_dir,
        save_name=stock_code,
    )

    if predict:
        x_test, y_test = data.get_test_data(
            seq_len=seq_len,
            normalise=configs['data']['normalise'],
        )
        predictions = model.predict_sequences_multiple(x_test, seq_len, seq_len)
        print("训练:\n", predictions)
        # plot_results_multiple(predictions, y_test, seq_len)

 

 

# 对指定公司的股票进行预测 

def prediction(stock_code, real=True, pre_len=30, plot=False):
    """Predict future values for one stock with its previously saved model.

    Args:
        stock_code: stock identifier; selects ``<data dir>/<stock_code>.csv``
            and ``saved_models/<stock_code>.h5``.
        real: when true, the most recent window is used and there is no
            ground truth to compare against. When false, the window
            ``sequence_length`` positions from the end is used so the true
            values that follow can be printed and plotted.
        pre_len: number of steps to predict.
        plot: when true, show the result with ``plot_results``.

    Returns:
        The list produced by ``format_predictions``:
        ``[[date_string, value], ...]``.
    """
    # Use a context manager so the config file handle is always closed.
    with open(get_config_path(), 'r') as config_file:
        configs = json.load(config_file)
    seq_len = configs['data']['sequence_length']

    data = DataLoader(
        os.path.join(get_data_path(), stock_code + ".csv"),
        configs['data']['train_test_split'],
        configs['data']['columns'],
    )

    # NOTE(review): train_model saves under configs['model']['save_dir'];
    # this path is hard-coded to "saved_models" - confirm both agree.
    file_path = os.path.join(get_parent_dir(), "saved_models", stock_code + ".h5")
    model = Model()
    keras.backend.clear_session()
    model.load_model(file_path)  # load the saved model from disk

    predict_length = pre_len
    win_position = -1 if real else -seq_len

    # First pass without normalisation: needed for the base value used to
    # undo normalisation, and for the true values to compare against.
    x_raw, y_raw = data.get_test_data(seq_len=seq_len, normalise=False)
    base = x_raw[win_position][0][0]

    y_test_real = None
    if not real:
        end = win_position + predict_length
        # A non-negative end would be read as an index from the start of the
        # array and yield an empty/wrong slice (e.g. [-50:0]); in that case
        # take everything up to the end of the data instead.
        y_test_real = y_raw[win_position:end if end < 0 else None]

    print("base value:\n", base)

    # Second pass with the configured normalisation: this is the model input.
    x_test, _ = data.get_test_data(
        seq_len=seq_len,
        normalise=configs['data']['normalise'],
    )
    x_test = x_test[win_position][np.newaxis, :, :]  # add a leading batch axis

    predictions = model.predict_1_win_sequence(x_test, seq_len, predict_length)

    # Undo normalisation. Presumably the loader normalises each window as
    # value / first_value - 1 - TODO confirm against DataLoader.
    predictions = (base * (1 + np.array(predictions))).tolist()

    if not real:
        print("真实数据:\n", y_test_real)

    if plot:
        plot_results(predictions, [] if real else y_test_real)

    return format_predictions(predictions)

 

def format_predictions(predictions, start_date=None):
    """Attach a weekday date to every predicted value.

    The first prediction is dated on the first weekday after ``start_date``;
    later predictions follow on consecutive weekdays (Saturdays and Sundays
    are skipped; public holidays are not considered).

    Args:
        predictions: iterable of predicted values.
        start_date: reference ``date``/``datetime`` the predictions follow.
            Defaults to the current local time, which was previously the
            only (hard-coded) option.

    Returns:
        A list of ``[ "YYYY-MM-DD", value ]`` pairs, one per prediction.
    """
    cur = datetime.now() if start_date is None else start_date
    cur += timedelta(days=1)

    date_predict = []
    for value in predictions:
        weekday = cur.isoweekday()
        if weekday >= 6:
            # Saturday (6) -> +2 days, Sunday (7) -> +1 day: jump to Monday.
            cur += timedelta(days=8 - weekday)
        date_predict.append([cur.strftime("%Y-%m-%d"), value])
        cur += timedelta(days=1)

    return date_predict

 

''' 

def main(stock_code, train=False, predict=False): 

configs = json.load(open(get_config_path(), 'r')) 

companies = configs['companies'] 

 

if stock_code not in companies.keys(): 

print("该公司不在指定范围内") 

return -1 

 

if train: 

train_model(stock_code) 

return 0 # 训练完成 

 

if predict: 

# for root, dirs, files in os.walk('saved_models'): 

# root:当前目录路径 dirs: 当前路径下所有子目录 files:当前路径下所有非目录子文件 

if stock_code + ".h5" in os.listdir("saved_models"): # os.listdir:获得当前目录下的所有文件名。不包括子目录 

return prediction(stock_code=stock_code, real=True, pre_len=20) 

else: 

return -2 # 该公司还没有训练模型 

''' 

# 二维数组:[[data,value],[...]] 

def get_hist_data(stock_code, recent_day=30):
    """Return the most recent Date/Close rows of one stock.

    Args:
        stock_code: stock identifier; ``<data dir>/<stock_code>.csv`` is read.
        recent_day: number of trailing rows to return. ``0`` yields an empty
            list (it previously returned every row because ``-0 == 0``).

    Returns:
        A list of ``[date, close]`` pairs in file order.

    Raises:
        KeyError: if the CSV has no 'Date' or 'Close' column.
    """
    # Presumably refreshes the local CSV for this stock - see
    # get_domestic_hist_stock.get_single_last_data.
    get_single_last_data(stock_code)

    # Build the path the same way prediction() does, without hard-coded "/".
    file_path = os.path.join(get_data_path(), stock_code + ".csv")
    data_frame = pd.read_csv(file_path)

    # Plain indexing raises a clear KeyError on a missing column, whereas
    # DataFrame.get would return None and fail later on ``.values``.
    return data_frame[['Date', 'Close']].tail(recent_day).values.tolist()

 

 

def train_all_stock():
    """Fetch data and train a model for every company listed in config.json.

    Calls ``get_all_last_data`` with a start date of 2010-01-01, then runs
    ``train_model`` (without evaluation) for each key of ``configs['companies']``.

    Returns:
        0 once every model has been trained.
    """
    get_all_last_data(start_date="2010-01-01")

    # Use a context manager so the config file handle is always closed.
    with open(get_config_path(), 'r') as config_file:
        configs = json.load(config_file)

    for stock_code in configs['companies']:
        train_model(stock_code)

    return 0

 

 

def predict_all_stock(pre_len=10):
    """Run ``prediction`` for every company listed in config.json.

    Args:
        pre_len: number of steps to predict for each stock.

    Returns:
        A list with one ``prediction`` result per company, in the iteration
        order of ``configs['companies']``.
    """
    # Use a context manager so the config file handle is always closed.
    with open(get_config_path(), 'r') as config_file:
        configs = json.load(config_file)

    return [
        prediction(stock_code=stock_code, real=True, pre_len=pre_len)
        for stock_code in configs['companies']
    ]

 

 

def get_config_path():
    """Return the path of ``config.json`` located next to this module."""
    return os.path.join(get_parent_dir(), "config.json")

 

 

def get_data_path():
    """Return the path of the ``data`` directory located next to this module."""
    return os.path.join(get_parent_dir(), "data")

 

 

def get_parent_dir():
    """Return the absolute path of the directory containing this file.

    ``__file__`` may be a relative path (for example when the module is run
    as a script on older interpreters), in which case ``dirname`` alone
    would give a relative or even empty result; ``abspath`` makes the
    returned directory independent of later working-directory changes.
    """
    return os.path.dirname(os.path.abspath(__file__))

 

 

# if __name__ == '__main__': 

# # get_all_last_data("2010-01-01") # 先获得最新数据 

# train_all_stock() 

# # predict_all_stock()