You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

117 lines
5.5 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import argparse
import os
from config import samples_mode
from config import samples_mode_flag
from config import logger
def del_output_file(parent_path):
    """
    Delete the previous data and output files directly inside ``parent_path``.

    Only top-level entries of the directory are considered. Sub-directories
    are left alone, and so is every entry whose name contains the substring
    '.py' (this keeps scripts, but also names such as 'x.pyc').

    Args:
        parent_path: Directory to clean. A trailing path separator is
            optional.

    Raises:
        OSError: If ``parent_path`` cannot be listed (propagated from
            ``os.listdir``). Failures to delete an individual file are
            logged and do not abort the clean-up.
    """
    for file_name in os.listdir(parent_path):
        # os.path.join copes with a missing trailing separator, whereas plain
        # string concatenation would silently build a wrong path.
        file_path = os.path.join(parent_path, file_name)
        if os.path.isdir(file_path) or '.py' in file_name:
            continue
        logger.info('删除文件: ' + file_path)
        try:
            # os.remove does not go through a shell, so names containing
            # spaces or shell metacharacters are deleted correctly.
            os.remove(file_path)
        except OSError as err:
            # Best effort, like the shell `rm` it replaces: report and go on.
            logger.warning('删除文件失败: ' + file_path + ' (' + str(err) + ')')
# Sub-projects every matching / ranking step is run for, in this order.
_DOMAINS = ('shixun', 'subject')

# Recall (matching) scripts in execution order; each one is run as
# ./matching/<domain>/<name>.py for every domain.
_RECALL_STEPS = (
    'build_keywords',
    'faiss_word2vec',
    'hnsw_faiss',
    'item_embedding',
    'item_merge_emb',
    'Item2Vec',
    'cold_start_recall',
    'item_embedding_recall',
    'itemcf_recall',
    'youtubednn_recall_train',
    'youtube_usercf_recall',
    'dssm_recall_train',
    'dssm_usercf_recall',
    'fm_recall_train',
    'mind_recall_train',
    'pinsage_recall_train',
)

# Ranking scripts in execution order; each one is run as
# ./ranking/<domain>/<name>.py for every domain.
_RANKING_STEPS = (
    # Ranking feature engineering
    'bert_embedding',
    'rank_features_engineering',
    # Ranking model training
    'xdeepfm_ranker_train',
    'difm_ranker_train',
    'bst_ranker_train',
)


def _run_script(script_path):
    """Run one pipeline script with ``python`` and return its exit status.

    A non-zero status is logged so that a failed step is visible in the
    training log. The pipeline is not aborted: later steps still run.
    """
    exit_status = os.system('python ' + script_path)
    if exit_status != 0:
        logger.warning('脚本执行失败 (退出状态 ' + str(exit_status) + '): python ' + script_path)
    return exit_status


def _run_steps(stage_dir, step_names):
    """Run every named step for each domain, step by step, domain by domain."""
    for step_name in step_names:
        for domain in _DOMAINS:
            _run_script('./' + stage_dir + '/' + domain + '/' + step_name + '.py')


def _confirm(prompt, silent_mode):
    """Return True if the user answers 'y'/'Y'; silent mode always confirms."""
    if silent_mode:
        return True
    return input(prompt + '\n').strip() in ('y', 'Y')


if __name__ == '__main__':
    # Entry point: optionally wipe previous outputs, then run the complete
    # recall and ranking training pipeline one script at a time.
    parser = argparse.ArgumentParser()
    parser.add_argument('--silent', action='store_true',
                        help='nohup python start_all_train.py --silent >./logs/all_model_train.log &')
    args = parser.parse_args()
    silent_mode = args.silent

    mode_name = '全量数据' if samples_mode_flag == 'full' else '增量数据'

    # NOTE(review): the nesting below was reconstructed from a copy of the
    # file without indentation. It assumes training runs whenever the first
    # prompt is confirmed, and only the clean-up depends on the second
    # prompt -- confirm against the original file.
    if _confirm('确认要开始重新训练' + mode_name + '的模型吗y/n', silent_mode):
        if _confirm('是否删除之前的数据和所有输出y/n', silent_mode):
            parent_path_list = [
                './data/' + samples_mode_flag + '/',
                './data/cold_start_data/shixun/' + samples_mode_flag + '/',
                './data/cold_start_data/subject/' + samples_mode_flag + '/',
                './results/shixun/' + samples_mode_flag + '/',
                './results/subject/' + samples_mode_flag + '/',
                './features/shixun/' + samples_mode_flag + '/',
                './features/subject/' + samples_mode_flag + '/',
                './models/shixun/' + samples_mode_flag + '/',
                './models/subject/' + samples_mode_flag + '/',
            ]
            for parent_path in parent_path_list:
                # makedirs also creates missing intermediate directories,
                # which a plain os.mkdir would fail on.
                os.makedirs(parent_path, exist_ok=True)
                del_output_file(parent_path)

        # Train the recall models
        _run_script('./data_process.py')
        _run_steps('matching', _RECALL_STEPS)

        # Computing every offline recall channel on the full data takes too
        # long, so the multi-channel recall results are merged only for
        # incremental data.
        # NOTE(review): presumably samples_mode is truthy in incremental
        # mode -- verify in config.
        if samples_mode:
            _run_steps('matching', ('multi_recall_combine',))

        # Ranking feature engineering, then ranking model training
        _run_steps('ranking', _RANKING_STEPS)