You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

136 lines
5.3 KiB

# 从 sqlalchemy 导入 func 函数,用于在数据库查询中使用函数,如获取当前时间
from sqlalchemy import func
# 从 app.models 模块导入数据库实例
from app.models import db
# 从 app.models.train_station_lib 模块导入 TrainStation 模型
from app.models.train_station_lib import TrainStation
class Train(db.Model):
"""
定义 Train 模型,对应数据库中的 'train' 表,用于存储列车信息。
"""
__tablename__: str = 'train'
# 定义列车 ID 列,作为主键,自增整数类型
id = db.Column(db.Integer, primary_key=True)
# 定义列车编号列,最大长度为 120 的字符串类型,唯一且不能为空,添加索引以提高查询效率
train_no = db.Column(db.String(120), unique=True, nullable=False, index=True)
# 定义出发站列,最大长度为 120 的字符串类型
departure_station = db.Column(db.String(120))
# 定义到达站列,最大长度为 120 的字符串类型
arrival_station = db.Column(db.String(120))
# 定义出发时间列,日期时间类型
departure_time = db.Column(db.DateTime)
# 定义过期时间列,日期时间类型
expiration_time = db.Column(db.DateTime)
# 定义生效时间列,日期时间类型
effective_time = db.Column(db.DateTime)
# 定义到达时间列,日期时间类型
arrival_time = db.Column(db.DateTime)
# 定义列车记录创建时间列,日期时间类型,默认值为当前时间
created_at = db.Column(db.DateTime, default=func.now())
# 定义列车记录更新时间列,日期时间类型,默认值为当前时间
updated_at = db.Column(db.DateTime, default=func.now())
def __repr__(self):
"""
返回一个可打印的字符串表示该列车对象,方便调试和日志记录。
:return: 包含列车 ID 的字符串
"""
return f'<Train {self.id}>'
@classmethod
def create(cls, new_train):
"""
将新的列车对象添加到数据库并提交更改。
:param new_train: 要创建的列车对象
:return: 创建成功的列车对象
"""
# 将新列车对象添加到数据库会话
db.session.add(new_train)
# 提交数据库会话,将更改保存到数据库
db.session.commit()
return new_train
@classmethod
def queryTrains(cls, from_station, to_station, date):
"""
根据出发站、到达站和日期查询符合条件的列车。
:param from_station: 出发站名称
:param to_station: 到达站名称
:param date: 查询日期
:return: 符合条件的列车列表
"""
# 查询出发站名称匹配 `from_station` 的所有列车站点记录
from_train = TrainStation.query.filter_by(station_name=from_station).all()
# 查询到达站名称匹配 `to_station` 的所有列车站点记录
to_train = TrainStation.query.filter_by(station_name=to_station).all()
# 从出发站查询结果中提取列车编号,存储在集合中以确保唯一性
from_train_nos = {ts.train_no for ts in from_train}
# 从到达站查询结果中提取列车编号,存储在集合中以确保唯一性
to_train_nos = {ts.train_no for ts in to_train}
# 找出两个集合中共同的列车编号
common_train_nos = from_train_nos & to_train_nos
# 过滤出出发站索引小于到达站索引的列车编号
valid_train_nos = [
train_no for train_no in common_train_nos
if next(ts.index for ts in from_train if ts.train_no == train_no) <
next(ts.index for ts in to_train if ts.train_no == train_no)
]
# 根据过滤后的列车编号和给定日期查询列车信息
trains = Train.query.filter(
Train.effective_time >= date,
Train.train_no.in_(valid_train_nos)
).all()
# 假设存在一个用于处理列车数据的展示器或序列化器
return trains
def buildTrain(params):
"""
根据传入的参数构建一个新的列车对象,并关联相应的列车站点信息。
:param params: 包含列车和站点信息的字典
:return: 构建好的列车对象
"""
# 创建一个新的 Train 对象
train = Train(
train_no=params['trainNo'],
effective_time=params['effective_time'],
expiration_time=params['expiration_time']
)
# 从站点信息中提取索引值
indexes = [e["index"] for e in params['stations']]
for e in params['stations']:
# 创建并关联 TrainStation 对象
train_station = TrainStation(
station_name=e["name"],
price=e["price"],
departure_time=e["depTime"],
arrival_time=e["arrTime"],
index=e["index"]
)
# 将列车站点对象添加到列车的站点列表中
train.train_stations.append(train_station)
# 确定第一站的出发时间和站点
if e["index"] == 0:
train.departure_time = e["depTime"]
train.departure_station = e["name"]
# 确定最后一站的到达时间和站点
if e["index"] == max(indexes):
train.arrival_time = e["arrTime"]
train.arrival_station = e["name"]
return train