You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
136 lines
5.3 KiB
136 lines
5.3 KiB
# 从 sqlalchemy 导入 func 函数,用于在数据库查询中使用函数,如获取当前时间
|
|
from sqlalchemy import func
|
|
# 从 app.models 模块导入数据库实例
|
|
from app.models import db
|
|
# 从 app.models.train_station_lib 模块导入 TrainStation 模型
|
|
from app.models.train_station_lib import TrainStation
|
|
|
|
class Train(db.Model):
|
|
"""
|
|
定义 Train 模型,对应数据库中的 'train' 表,用于存储列车信息。
|
|
"""
|
|
__tablename__: str = 'train'
|
|
|
|
# 定义列车 ID 列,作为主键,自增整数类型
|
|
id = db.Column(db.Integer, primary_key=True)
|
|
# 定义列车编号列,最大长度为 120 的字符串类型,唯一且不能为空,添加索引以提高查询效率
|
|
train_no = db.Column(db.String(120), unique=True, nullable=False, index=True)
|
|
# 定义出发站列,最大长度为 120 的字符串类型
|
|
departure_station = db.Column(db.String(120))
|
|
# 定义到达站列,最大长度为 120 的字符串类型
|
|
arrival_station = db.Column(db.String(120))
|
|
# 定义出发时间列,日期时间类型
|
|
departure_time = db.Column(db.DateTime)
|
|
# 定义过期时间列,日期时间类型
|
|
expiration_time = db.Column(db.DateTime)
|
|
# 定义生效时间列,日期时间类型
|
|
effective_time = db.Column(db.DateTime)
|
|
# 定义到达时间列,日期时间类型
|
|
arrival_time = db.Column(db.DateTime)
|
|
# 定义列车记录创建时间列,日期时间类型,默认值为当前时间
|
|
created_at = db.Column(db.DateTime, default=func.now())
|
|
# 定义列车记录更新时间列,日期时间类型,默认值为当前时间
|
|
updated_at = db.Column(db.DateTime, default=func.now())
|
|
|
|
def __repr__(self):
|
|
"""
|
|
返回一个可打印的字符串表示该列车对象,方便调试和日志记录。
|
|
|
|
:return: 包含列车 ID 的字符串
|
|
"""
|
|
return f'<Train {self.id}>'
|
|
|
|
@classmethod
|
|
def create(cls, new_train):
|
|
"""
|
|
将新的列车对象添加到数据库并提交更改。
|
|
|
|
:param new_train: 要创建的列车对象
|
|
:return: 创建成功的列车对象
|
|
"""
|
|
# 将新列车对象添加到数据库会话
|
|
db.session.add(new_train)
|
|
# 提交数据库会话,将更改保存到数据库
|
|
db.session.commit()
|
|
return new_train
|
|
|
|
@classmethod
|
|
def queryTrains(cls, from_station, to_station, date):
|
|
"""
|
|
根据出发站、到达站和日期查询符合条件的列车。
|
|
|
|
:param from_station: 出发站名称
|
|
:param to_station: 到达站名称
|
|
:param date: 查询日期
|
|
:return: 符合条件的列车列表
|
|
"""
|
|
# 查询出发站名称匹配 `from_station` 的所有列车站点记录
|
|
from_train = TrainStation.query.filter_by(station_name=from_station).all()
|
|
|
|
# 查询到达站名称匹配 `to_station` 的所有列车站点记录
|
|
to_train = TrainStation.query.filter_by(station_name=to_station).all()
|
|
|
|
# 从出发站查询结果中提取列车编号,存储在集合中以确保唯一性
|
|
from_train_nos = {ts.train_no for ts in from_train}
|
|
|
|
# 从到达站查询结果中提取列车编号,存储在集合中以确保唯一性
|
|
to_train_nos = {ts.train_no for ts in to_train}
|
|
|
|
# 找出两个集合中共同的列车编号
|
|
common_train_nos = from_train_nos & to_train_nos
|
|
|
|
# 过滤出出发站索引小于到达站索引的列车编号
|
|
valid_train_nos = [
|
|
train_no for train_no in common_train_nos
|
|
if next(ts.index for ts in from_train if ts.train_no == train_no) <
|
|
next(ts.index for ts in to_train if ts.train_no == train_no)
|
|
]
|
|
# 根据过滤后的列车编号和给定日期查询列车信息
|
|
trains = Train.query.filter(
|
|
Train.effective_time >= date,
|
|
Train.train_no.in_(valid_train_nos)
|
|
).all()
|
|
|
|
# 假设存在一个用于处理列车数据的展示器或序列化器
|
|
return trains
|
|
|
|
|
|
def buildTrain(params):
|
|
"""
|
|
根据传入的参数构建一个新的列车对象,并关联相应的列车站点信息。
|
|
|
|
:param params: 包含列车和站点信息的字典
|
|
:return: 构建好的列车对象
|
|
"""
|
|
# 创建一个新的 Train 对象
|
|
train = Train(
|
|
train_no=params['trainNo'],
|
|
effective_time=params['effective_time'],
|
|
expiration_time=params['expiration_time']
|
|
)
|
|
|
|
# 从站点信息中提取索引值
|
|
indexes = [e["index"] for e in params['stations']]
|
|
|
|
for e in params['stations']:
|
|
# 创建并关联 TrainStation 对象
|
|
train_station = TrainStation(
|
|
station_name=e["name"],
|
|
price=e["price"],
|
|
departure_time=e["depTime"],
|
|
arrival_time=e["arrTime"],
|
|
index=e["index"]
|
|
)
|
|
# 将列车站点对象添加到列车的站点列表中
|
|
train.train_stations.append(train_station)
|
|
# 确定第一站的出发时间和站点
|
|
if e["index"] == 0:
|
|
train.departure_time = e["depTime"]
|
|
train.departure_station = e["name"]
|
|
|
|
# 确定最后一站的到达时间和站点
|
|
if e["index"] == max(indexes):
|
|
train.arrival_time = e["arrTime"]
|
|
train.arrival_station = e["name"]
|
|
|
|
return train |