import argparse
import os

from config import samples_mode
from config import samples_mode_flag
from config import logger

# Sub-directory names that every per-domain script and output folder exists
# under. Scripts are always run for the first entry, then the second.
_DOMAINS = ('shixun', 'subject')

# Recall-stage script names (without ".py"), in execution order. Each one is
# run as ./matching/<domain>/<name>.py for every entry of _DOMAINS.
_RECALL_SCRIPTS = (
    'build_keywords',
    'faiss_word2vec',
    'hnsw_faiss',
    'item_embedding',
    'item_merge_emb',
    'Item2Vec',
    'cold_start_recall',
    'item_embedding_recall',
    'itemcf_recall',
    'youtubednn_recall_train',
    'youtube_usercf_recall',
    'dssm_recall_train',
    'dssm_usercf_recall',
    'fm_recall_train',
    'mind_recall_train',
    'pinsage_recall_train',
)

# Ranking feature-engineering scripts under ./ranking/<domain>/.
_RANK_FEATURE_SCRIPTS = (
    'bert_embedding',
    'rank_features_engineering',
)

# Ranking model training scripts under ./ranking/<domain>/.
_RANKER_SCRIPTS = (
    'xdeepfm_ranker_train',
    'difm_ranker_train',
    'bst_ranker_train',
)


def del_output_file(parent_path):
    """Delete previous data and output files directly inside ``parent_path``.

    Only direct children are considered. Sub-directories are left alone, and
    so is any entry whose name contains ``'.py'``. Deletion is best-effort: a
    file that cannot be removed is logged and skipped.

    Args:
        parent_path: Directory to clean. A trailing slash is optional.
    """
    for file_name in os.listdir(parent_path):
        file_path = os.path.join(parent_path, file_name)
        if os.path.isdir(file_path) or '.py' in file_name:
            continue
        logger.info('删除文件: ' + file_path)
        try:
            # os.remove avoids building a shell command, so names containing
            # spaces or shell metacharacters are handled correctly.
            os.remove(file_path)
        except OSError as err:
            logger.warning('删除文件失败: ' + file_path + ' (' + str(err) + ')')


def _run_script(script_path):
    """Run ``python <script_path>`` via the shell and log a non-zero status.

    Returns the raw status value from ``os.system``. A failure does not stop
    the caller; it is only reported.
    """
    command = 'python ' + script_path
    status = os.system(command)
    if status != 0:
        logger.warning('命令执行失败: ' + command + ' (status=' + str(status) + ')')
    return status


def _run_stage(stage_dir, script_names):
    """Run every named script under ``./<stage_dir>/<domain>/`` for each domain."""
    for script_name in script_names:
        for domain in _DOMAINS:
            _run_script('./' + stage_dir + '/' + domain + '/' + script_name + '.py')


def _output_dirs():
    """Return the data/result/feature/model directories for the current mode."""
    dirs = ['./data/' + samples_mode_flag + '/']
    for base in ('./data/cold_start_data/', './results/', './features/', './models/'):
        for domain in _DOMAINS:
            dirs.append(base + domain + '/' + samples_mode_flag + '/')
    return dirs


def _confirm(prompt, silent):
    """Return True if running silently or the user answers ``y``/``Y``."""
    if silent:
        return True
    return input(prompt).strip() in ('y', 'Y')


def main():
    """Optionally wipe previous outputs, then run the full training pipeline."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--silent',
                        action='store_true',
                        help='nohup python start_all_train.py --silent >./logs/all_model_train.log &')
    args = parser.parse_args()
    silent_mode = args.silent

    mode_name = '全量数据' if samples_mode_flag == 'full' else '增量数据'

    if not _confirm('确认要开始重新训练' + mode_name + '的模型吗?y/n' + '\n', silent_mode):
        return

    if _confirm('是否删除之前的数据和所有输出?y/n' + '\n', silent_mode):
        for parent_path in _output_dirs():
            # makedirs also creates missing intermediate directories, which
            # a plain mkdir would fail on.
            os.makedirs(parent_path, exist_ok=True)
            del_output_file(parent_path)

    # Train the recall models.
    _run_script('./data_process.py')
    _run_stage('matching', _RECALL_SCRIPTS)

    # Computing every offline recall channel on the full data takes too long,
    # so the multi-channel recall results are merged only for incremental data.
    if samples_mode:
        _run_stage('matching', ('multi_recall_combine',))

    # Ranking feature engineering.
    _run_stage('ranking', _RANK_FEATURE_SCRIPTS)

    # Train the ranking models.
    _run_stage('ranking', _RANKER_SCRIPTS)


if __name__ == '__main__':
    main()