You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

117 lines
5.5 KiB

5 months ago
import argparse
import os
import subprocess
import sys

from config import logger
from config import samples_mode
from config import samples_mode_flag


def del_output_file(parent_path):
    """
    Delete previous data and outputs stored directly in ``parent_path``.

    Every non-directory entry whose name does not contain ``.py`` is removed.
    Sub-directories (and their contents) and any name containing ``.py`` are
    left untouched; the directory is not walked recursively.

    :param parent_path: directory to clean; a trailing separator is optional.
    """
    for file_name in os.listdir(parent_path):
        # Substring (not suffix) match kept on purpose: anything whose name
        # contains ".py" (e.g. ".pyc") is preserved, exactly as before.
        if '.py' in file_name:
            continue
        # os.path.join works whether or not parent_path ends with a separator.
        file_path = os.path.join(parent_path, file_name)
        if os.path.isdir(file_path):
            continue
        logger.info('删除文件: %s', file_path)
        try:
            # os.remove instead of shelling out to `rm`: no shell is involved,
            # so names with spaces or shell metacharacters are handled safely.
            os.remove(file_path)
        except FileNotFoundError:
            # Removed by someone else since listdir(); nothing left to do.
            pass
        except OSError as err:
            # Best-effort like the old `rm` call, but make the failure visible.
            logger.warning('Failed to delete file %s: %s', file_path, err)
# Every per-domain step is run for "shixun" first, then for "subject".
_DOMAINS = ('shixun', 'subject')

# Recall (matching) steps under ./matching/<domain>/, in execution order.
_RECALL_SCRIPTS = (
    'build_keywords',
    'faiss_word2vec',
    'hnsw_faiss',
    'item_embedding',
    'item_merge_emb',
    'Item2Vec',
    'cold_start_recall',
    'item_embedding_recall',
    'itemcf_recall',
    'youtubednn_recall_train',
    'youtube_usercf_recall',
    'dssm_recall_train',
    'dssm_usercf_recall',
    'fm_recall_train',
    'mind_recall_train',
    'pinsage_recall_train',
)

# Ranking feature-engineering steps under ./ranking/<domain>/.
_RANK_FEATURE_SCRIPTS = (
    'bert_embedding',
    'rank_features_engineering',
)

# Ranking-model training steps under ./ranking/<domain>/.
_RANKER_SCRIPTS = (
    'xdeepfm_ranker_train',
    'difm_ranker_train',
    'bst_ranker_train',
)


def _paired_scripts(stage_dir, names):
    """Return ``<stage_dir>/<domain>/<name>.py`` for each name, shixun before subject."""
    return [stage_dir + '/' + domain + '/' + name + '.py'
            for name in names
            for domain in _DOMAINS]


def _confirmed(prompt, silent):
    """
    Ask a y/n question.

    :param prompt: question shown to the user (a newline is appended).
    :param silent: when true, skip the prompt and answer yes.
    :return: True if the answer (whitespace-stripped) is 'y' or 'Y'.
    """
    if silent:
        return True
    return input(prompt + '\n').strip() in ('y', 'Y')


def _output_dirs(flag):
    """Return the data/result/feature/model directories cleaned for sample mode ``flag``."""
    return [
        './data/' + flag + '/',
        './data/cold_start_data/shixun/' + flag + '/',
        './data/cold_start_data/subject/' + flag + '/',
        './results/shixun/' + flag + '/',
        './results/subject/' + flag + '/',
        './features/shixun/' + flag + '/',
        './features/subject/' + flag + '/',
        './models/shixun/' + flag + '/',
        './models/subject/' + flag + '/',
    ]


def _reset_output_dirs(flag):
    """Ensure every output directory for ``flag`` exists, then delete its previous files."""
    for path in _output_dirs(flag):
        # makedirs (not mkdir) so missing parents such as ./models/shixun/
        # are created too instead of raising FileNotFoundError.
        os.makedirs(path, exist_ok=True)
        del_output_file(path)


def _run_script(script):
    """
    Run one pipeline step in a child process using the current interpreter.

    A non-zero exit status is logged and returned rather than raised, so the
    remaining steps still run (as with the previous os.system calls), but a
    failure no longer goes unnoticed.

    :param script: path of the Python script, relative to the working directory.
    :return: the child's exit status.
    """
    # sys.executable instead of whatever "python" is on PATH, so every step
    # runs in the same interpreter/environment as this launcher.
    returncode = subprocess.run([sys.executable, script]).returncode
    if returncode != 0:
        logger.error('Step failed with exit code %d: %s', returncode, script)
    return returncode


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--silent', action='store_true',
                        help='nohup python start_all_train.py --silent >./logs/all_model_train.log &')
    args = parser.parse_args()
    silent_mode = args.silent

    mode_name = '全量数据' if samples_mode_flag == 'full' else '增量数据'

    if _confirmed('确认要开始重新训练' + mode_name + '的模型吗y/n', silent_mode):
        if _confirmed('是否删除之前的数据和所有输出y/n', silent_mode):
            _reset_output_dirs(samples_mode_flag)

        # Train the recall models
        _run_script('./data_process.py')
        for script in _paired_scripts('./matching', _RECALL_SCRIPTS):
            _run_script(script)

        # Computing every offline recall channel on the full dataset takes too
        # long, so multi-channel recall results are merged only for incremental data.
        if samples_mode:
            for script in _paired_scripts('./matching', ('multi_recall_combine',)):
                _run_script(script)

        # NOTE(review): the ranking steps below are assumed to run in both full
        # and incremental mode (not only when samples_mode is set) — confirm.
        # Ranking feature engineering
        for script in _paired_scripts('./ranking', _RANK_FEATURE_SCRIPTS):
            _run_script(script)
        # Train the ranking models
        for script in _paired_scripts('./ranking', _RANKER_SCRIPTS):
            _run_script(script)