import json import os from django.http import HttpResponse from django.shortcuts import render, get_object_or_404 from LSTMPredictStock import run from stock_predict import models from datetime import datetime as dt from apscheduler.scheduler import Scheduler from .models import Company import pandas as pd LOCAL = False def get_hist_predict_data(stock_code): recent_data,predict_data = None,None # company = models.Company.objects.get(stock_code=stock_code) company = get_object_or_404(Company, stock_code=stock_code) if company.historydata_set.count() <= 0: history_data = models.HistoryData() history_data.company = company history_data.set_data(run.get_hist_data(stock_code=stock_code,recent_day=20)) history_data.save() recent_data = history_data.get_data() else: all_data = company.historydata_set.all() for single in all_data: now = dt.now() end_date = single.get_data()[-1][0] end_date = dt.strptime(end_date,"%Y-%m-%d") if LOCAL & (now.date() > end_date.date()): # 更新预测数据 single.set_data(run.get_hist_data(stock_code=stock_code,recent_day=20)) single.save() recent_data = single.get_data() break if company.predictdata_set.count() <= 0: predict_data = models.PredictData() predict_data.company = company predict_data.set_data(run.prediction(stock_code,pre_len=10)) predict_data.save() predict_data = predict_data.get_data() else: all_data = company.predictdata_set.all() for single in all_data: now = dt.now() start_date = dt.strptime(single.start_date,"%Y-%m-%d") if LOCAL & (now.date() > start_date.date()): # 更新预测数据 single.set_data(run.prediction(stock_code, pre_len=10)) single.save() predict_data = single.get_data() break return recent_data,predict_data def get_crawl_save_data(): """ 将10个公司的指标数据爬取并保存到数据库 """ # 此处应是从网上爬取数据,并保存为csv文件 parent_dir = os.path.dirname(__file__) # "stock_predict/views.py" file_dir = os.path.join(parent_dir, "stock_index/") for file_name in os.listdir(file_dir): file_path = os.path.join(file_dir, file_name) data_frame = pd.read_csv(file_path) stock_code = file_name.split('.')[0] company = get_object_or_404(Company, stock_code=stock_code) for index,row in data_frame.iterrows(): company.stockindex_set.create(ri_qi=row['ri_qi'],zi_jin=row['zi_jin'],qiang_du=row['qiang_du'],feng_xian=row['feng_xian'], zhuan_qiang=row['zhuan_qiang'],chang_yu=row['chang_yu'],jin_zi=row['jin_zi'],zong_he=row['zong_he']) def get_stock_index(stock_code): """ 获取股票的各项指标数据 """ company = get_object_or_404(Company, stock_code=stock_code) if company.stockindex_set.count() <= 0: # 将爬取的数据存入数据库 get_crawl_save_data() # 从数据库获取近三天的数据 indexs = company.stockindex_set.all().order_by('-ri_qi')[:3].values() return list(indexs) def home(request): recent_data,predict_data = get_hist_predict_data("600718") data = {"recent_data":recent_data,"stock_code":"600718","predict_data":predict_data} data['indexs'] = get_stock_index("600718") return render(request,"stock_predict/home.html",{"data":json.dumps(data)}) # json.dumps(list) def predict_stock_action(request): stock_code = request.POST.get('stock_code',None) # print("stock_code:\n",stock_code) recent_data, predict_data = get_hist_predict_data(stock_code) data = {"recent_data": recent_data, "stock_code": stock_code, "predict_data": predict_data} data['indexs'] = get_stock_index(stock_code) return render(request, "stock_predict/home.html", {"data": json.dumps(data)}) # json.dumps(list) sched = Scheduler() # 定时任务 # @sched.interval_schedule(seconds=2) # 每2s执行一次 @sched.cron_schedule(hour=0,minute=0) # 每日凌晨调度一次 def train_models(): run.train_all_stock() sched.start()