import json
import os
from django.http import HttpResponse
from django.shortcuts import render, get_object_or_404
from LSTMPredictStock import run
from stock_predict import models
from datetime import datetime as dt
from apscheduler.scheduler import Scheduler
from .models import Company
import pandas as pd
# LOCAL = False
def get_hist_predict_data(stock_code):
    """Return the recent price history and the price forecast for a stock.

    Both series are cached in the database as ``HistoryData`` /
    ``PredictData`` rows attached to the ``Company``. A cached series is
    (re)generated through ``run.get_hist_data`` / ``run.prediction`` when it
    is missing or out of date, and the stored data is returned.

    :param stock_code: stock code used to look up the ``Company``.
    :return: tuple ``(recent_data, predict_data)``, each as returned by the
        corresponding model's ``get_data()``.
    :raises Http404: if no ``Company`` has this ``stock_code``.
    """
    company = get_object_or_404(Company, stock_code=stock_code)
    today = dt.now().date()
    recent_data = _load_recent_history(company, stock_code, today)
    predict_data = _load_prediction(company, stock_code, today)
    return recent_data, predict_data


def _load_recent_history(company, stock_code, today):
    """Return the cached 20-day history of *company*, refreshing it when stale."""
    history = company.historydata_set.first()
    if history is None:
        # Nothing cached yet: fetch and store a fresh series.
        history = models.HistoryData()
        history.company = company
        history.set_data(run.get_hist_data(stock_code=stock_code, recent_day=20))
        history.save()
        return history.get_data()
    data = history.get_data()
    # data[-1][0] is treated as the series' end date ("%Y-%m-%d"). An empty
    # series has no end date, so it is refetched instead of raising IndexError.
    if not data or dt.strptime(data[-1][0], "%Y-%m-%d").date() < today:
        history.set_data(run.get_hist_data(stock_code=stock_code, recent_day=20))
        history.save()
        data = history.get_data()
    return data


def _load_prediction(company, stock_code, today):
    """Return the cached 10-step forecast of *company*, regenerating it when stale."""
    prediction = company.predictdata_set.first()
    if prediction is None:
        # Nothing cached yet: run the model and store its output.
        prediction = models.PredictData()
        prediction.company = company
        prediction.set_data(run.prediction(stock_code, pre_len=10))
        prediction.save()
        return prediction.get_data()
    # start_date is parsed as a "%Y-%m-%d" string; a forecast whose start
    # date lies before today is regenerated.
    if dt.strptime(prediction.start_date, "%Y-%m-%d").date() < today:
        prediction.set_data(run.prediction(stock_code, pre_len=10))
        prediction.save()
    return prediction.get_data()
def get_crawl_save_data():
    """
    Load the per-company indicator CSV files into the database.

    Stand-in for a real crawler: reads every ``<stock_code>.csv`` file in the
    ``stock_index`` directory next to this module and creates one
    ``stockindex`` row per CSV row for the matching ``Company``.

    Companies that already have indicator rows are skipped. ``get_stock_index``
    calls this loader whenever a single company has no rows, and re-importing
    every file would otherwise duplicate all other companies' data. Entries that
    are not ``.csv`` files, and files whose stock code matches no ``Company``,
    are skipped rather than aborting the whole import.
    """
    import logging

    logger = logging.getLogger(__name__)
    # TODO: this should crawl the data from the web and save it as CSV files.
    file_dir = os.path.join(os.path.dirname(__file__), "stock_index")
    columns = ('ri_qi', 'zi_jin', 'qiang_du', 'feng_xian',
               'zhuan_qiang', 'chang_yu', 'jin_zi', 'zong_he')
    for file_name in sorted(os.listdir(file_dir)):
        if not file_name.lower().endswith('.csv'):
            continue  # e.g. .DS_Store or sub-directories, which read_csv cannot parse
        stock_code = file_name.split('.')[0]
        company = Company.objects.filter(stock_code=stock_code).first()
        if company is None:
            logger.warning("No Company for stock index file %s; skipped", file_name)
            continue
        if company.stockindex_set.exists():
            continue  # already imported; re-inserting would duplicate rows
        data_frame = pd.read_csv(os.path.join(file_dir, file_name))
        for _, row in data_frame.iterrows():
            company.stockindex_set.create(**{col: row[col] for col in columns})
def get_stock_index(stock_code):
    """
    Return the three indicator records of a stock with the highest ``ri_qi``.

    When the company has no indicator rows yet, ``get_crawl_save_data`` is run
    first to populate the table. Records are dicts produced by
    ``QuerySet.values()``, ordered by ``ri_qi`` descending.

    :raises Http404: if no ``Company`` has this ``stock_code``.
    """
    company = get_object_or_404(Company, stock_code=stock_code)
    if not company.stockindex_set.exists():
        # Empty table for this company: import the crawled data first.
        get_crawl_save_data()
    latest = company.stockindex_set.order_by('-ri_qi').values()[:3]
    return list(latest)
def home(request):
    """Render the home page for the default stock ``600718``.

    The template receives a single ``data`` variable: a JSON string holding
    the recent history, the stock code, the forecast and the latest
    indicator records.
    """
    default_code = "600718"
    recent, forecast = get_hist_predict_data(default_code)
    payload = {
        "recent_data": recent,
        "stock_code": default_code,
        "predict_data": forecast,
        "indexs": get_stock_index(default_code),
    }
    return render(request, "stock_predict/home.html", {"data": json.dumps(payload)})
def predict_stock_action(request):
    """Render the home page for the stock code submitted in the POST form.

    :param request: request whose POST body carries ``stock_code``.
    :return: the rendered ``stock_predict/home.html`` page, whose ``data``
        variable is a JSON string with the recent history, the stock code,
        the forecast and the latest indicator records; or a 400 response when
        ``stock_code`` is missing or blank.
    :raises Http404: if no ``Company`` has the submitted code.
    """
    # Tolerate stray whitespace in user input; a missing field becomes "".
    stock_code = (request.POST.get('stock_code') or '').strip()
    if not stock_code:
        # Reject empty input up front: looking up a blank/None code would only
        # end in Http404, misreporting a client error as "not found".
        return HttpResponse("Missing stock_code", status=400)
    recent_data, predict_data = get_hist_predict_data(stock_code)
    data = {"recent_data": recent_data, "stock_code": stock_code, "predict_data": predict_data}
    data['indexs'] = get_stock_index(stock_code)
    return render(request, "stock_predict/home.html", {"data": json.dumps(data)})
# Disabled scheduled-training job. It is kept as a bare module-level string
# literal, so none of it executes. If re-enabled, it would create a
# ``Scheduler`` and register ``train_models`` to call ``run.train_all_stock()``
# via ``cron_schedule(hour=0, minute=0)``, i.e. daily at 00:00. The commented
# ``interval_schedule(seconds=2)`` line is an alternative every-2-seconds
# schedule. The Chinese comments inside the string read: "scheduled task",
# "run every 2 s", "schedule once daily in the early morning (midnight)".
# NOTE(review): ``sched.start()`` at module level would start the scheduler in
# every process that imports this module — confirm the deployment model (and
# the installed APScheduler version's API) before re-enabling.
'''
sched = Scheduler()
# 定时任务
# @sched.interval_schedule(seconds=2) # 每2s执行一次
@sched.cron_schedule(hour=0,minute=0) # 每日凌晨调度一次
def train_models():
run.train_all_stock()
sched.start()
'''
|