You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
92 lines
3.3 KiB
92 lines
3.3 KiB
from sqlalchemy import func
|
|
|
|
from app.models import db
|
|
from app.models.train_station_lib import TrainStation
|
|
|
|
class Train(db.Model):
    """A train service stored in the ``train`` table.

    Each row records the train number, the names and times of its first and
    last stations, and an ``effective_time`` / ``expiration_time`` pair
    (presumably the period the schedule is valid for — confirm). Per-stop
    details are ``TrainStation`` rows linked by ``train_no``.
    """

    __tablename__: str = 'train'

    id = db.Column(db.Integer, primary_key=True)
    # Public identifier; unique and indexed because trains are looked up by it.
    train_no = db.Column(db.String(120), unique=True, nullable=False, index=True)
    departure_station = db.Column(db.String(120))
    arrival_station = db.Column(db.String(120))
    departure_time = db.Column(db.DateTime)
    expiration_time = db.Column(db.DateTime)
    effective_time = db.Column(db.DateTime)
    arrival_time = db.Column(db.DateTime)
    created_at = db.Column(db.DateTime, default=func.now())
    # onupdate refreshes the timestamp on every ORM UPDATE; without it the
    # column would keep its insert-time value forever.
    updated_at = db.Column(db.DateTime, default=func.now(), onupdate=func.now())

    def __repr__(self):
        return f'<Train {self.id}>'

    @classmethod
    def create(cls, new_train):
        """Add ``new_train`` to the session and commit.

        Args:
            new_train: A ``Train`` instance, e.g. one returned by ``buildTrain``.

        Returns:
            The same instance after the commit.

        Raises:
            Whatever ``db.session.commit()`` raises. The session is rolled back
            first so it stays usable afterwards.
        """
        db.session.add(new_train)
        try:
            db.session.commit()
        except Exception:
            # A failed commit leaves the session unusable until it is rolled
            # back; undo, then let the caller see the original error.
            db.session.rollback()
            raise
        return new_train

    @classmethod
    def queryTrains(cls, from_station, to_station, date):
        """Return trains that stop at ``from_station`` before ``to_station``.

        A train qualifies when it has ``TrainStation`` rows for both stations
        and the stop ``index`` at ``from_station`` is smaller than the one at
        ``to_station``, i.e. it travels in the requested direction.

        Args:
            from_station: Name of the boarding station.
            to_station: Name of the destination station.
            date: Value compared against ``Train.effective_time``.

        Returns:
            A list of ``Train`` instances (empty if none match).
        """
        # Map train_no -> stop index at each station. setdefault keeps the
        # first row per train, matching a first-match scan, and gives O(1)
        # lookups instead of rescanning the result lists for every train.
        from_index = {}
        for ts in TrainStation.query.filter_by(station_name=from_station).all():
            from_index.setdefault(ts.train_no, ts.index)

        to_index = {}
        for ts in TrainStation.query.filter_by(station_name=to_station).all():
            to_index.setdefault(ts.train_no, ts.index)

        # Trains serving both stations, in the from -> to direction.
        valid_train_nos = [
            train_no
            for train_no in from_index.keys() & to_index.keys()
            if from_index[train_no] < to_index[train_no]
        ]
        if not valid_train_nos:
            # Nothing can match; skip the round trip and the empty IN clause.
            return []

        # NOTE(review): this keeps trains whose effective_time is on/after
        # `date`. If the intent is "trains running on `date`", the filter is
        # probably effective_time <= date <= expiration_time — confirm.
        return cls.query.filter(
            cls.effective_time >= date,
            cls.train_no.in_(valid_train_nos),
        ).all()
|
|
|
|
|
|
def buildTrain(params):
    """Build an unsaved ``Train`` and its ``TrainStation`` stops from ``params``.

    Args:
        params: Mapping with keys ``'trainNo'``, ``'effective_time'``,
            ``'expiration_time'`` and ``'stations'``. Each station is a
            mapping with ``'name'``, ``'price'``, ``'depTime'``, ``'arrTime'``
            and ``'index'`` (the stop's position along the route).

    Returns:
        The new ``Train``. It is not added to the database session; pass it
        to ``Train.create`` to persist it.
    """
    train = Train(
        train_no=params['trainNo'],
        effective_time=params['effective_time'],
        expiration_time=params['expiration_time'],
    )

    stations = params['stations']
    if not stations:
        return train

    # Compute the route endpoints once, rather than calling max() per stop.
    # The origin is the lowest index, so 1-based routes also get a departure
    # station (index 0 was previously hard-coded).
    indexes = [e["index"] for e in stations]
    first_index = min(indexes)
    last_index = max(indexes)

    for e in stations:
        # NOTE(review): `train_stations` is not declared on Train in this
        # file; presumably it is a relationship/backref from TrainStation.
        train.train_stations.append(
            TrainStation(
                station_name=e["name"],
                price=e["price"],
                departure_time=e["depTime"],
                arrival_time=e["arrTime"],
                index=e["index"],
            )
        )

        # The lowest-indexed stop supplies the departure station and time...
        if e["index"] == first_index:
            train.departure_time = e["depTime"]
            train.departure_station = e["name"]

        # ...and the highest-indexed stop supplies the arrival station and time.
        if e["index"] == last_index:
            train.arrival_time = e["arrTime"]
            train.arrival_station = e["name"]

    return train
|