You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
168 lines
6.6 KiB
168 lines
6.6 KiB
import os
|
|
import sys
|
|
sys.path.append(os.getcwd())
|
|
from tqdm import tqdm
|
|
import warnings
|
|
import tensorflow as tf
|
|
from tensorflow.python.keras import backend as K
|
|
from tensorflow.python.keras.models import Model
|
|
import collections
|
|
import pickle
|
|
from libreco.algorithms import PinSage
|
|
from libreco.data import DatasetFeat
|
|
from config import logger, offline_mode
|
|
from config import need_metric_recall
|
|
from config import subject_pinsage_model_path
|
|
from config import subject_pinsage_recall_dict
|
|
from matching.subject.recall_comm import get_all_select_df
|
|
from matching.subject.recall_comm import get_user_info_df,get_item_info_df
|
|
from matching.subject.recall_comm import metrics_pinsage_recall,get_all_hist_and_last_select
|
|
|
|
# Module-level runtime setup, executed once on import.

# Put Keras into training mode (learning phase 1).
K.set_learning_phase(True)

# Disable eager execution under TensorFlow 2.x, presumably because the
# training code relies on TF1-style graph mode (reset_state below calls
# tf.compat.v1.reset_default_graph) -- TODO confirm it is still needed.
# The numeric major version is compared on purpose: the former string test
# `tf.__version__ >= '2.0.0'` is lexicographic and would misclassify e.g. '10.0.0'.
if int(tf.__version__.split('.')[0]) >= 2:
    tf.compat.v1.disable_eager_execution()

# Register tqdm's progress_apply / progress_map helpers on pandas objects.
tqdm.pandas()

# Silence all Python warnings for the rest of the process.
warnings.filterwarnings('ignore')
|
|
|
|
def reset_state(name):
    """Clear the TF1-compat default graph and print a section banner.

    Args:
        name: label shown between two runs of 30 '=' characters.
    """
    rule = "=" * 30
    tf.compat.v1.reset_default_graph()
    print("\n", rule, name, rule)
|
|
|
|
|
|
def _to_pinsage_columns(df):
    """Return a copy of *df* shaped for PinSage.

    Adds a constant ``label`` column of 1, marking every row as a positive
    interaction. Renames ``user_id`` / ``subject_id`` to the ``user`` / ``item``
    column names the model expects. The input frame is not modified.
    """
    return df.assign(label=1).rename(columns={'user_id': 'user', 'subject_id': 'item'})


def _split_hist_and_eval(all_select_df):
    """Split the joined selection frame into ``(train_df, eval_df)``.

    When ``need_metric_recall`` is set, the last selection is held out for
    recall evaluation via ``get_all_hist_and_last_select`` (presumably per
    user -- see that helper). Otherwise the whole frame is used for training,
    and a random 0.1% sample of it serves as the evaluation frame.
    """
    if need_metric_recall:
        logger.info('获取物品行为数据历史和最后一次选择')
        hist_df, last_df = get_all_hist_and_last_select(all_select_df)
        # Build new frames rather than writing into what may be slices of
        # all_select_df.
        return _to_pinsage_columns(hist_df), _to_pinsage_columns(last_df)

    hist_df = _to_pinsage_columns(all_select_df)
    # Unseeded sample, as before: this evaluation subset varies between runs.
    return hist_df, hist_df.sample(frac=0.001)


def _recall_with_scores(pinsage, user_ids, recall_item_num):
    """Recommend ``recall_item_num`` items per user and attach predicted scores.

    Returns:
        A ``defaultdict`` mapping each user id to a list of ``(item, score)``
        tuples, in the order returned by ``recommend_user``.
    """
    # default_factory stays `dict` so the pickled object matches earlier runs.
    user_recall_items_dict = collections.defaultdict(dict)
    for user_id in tqdm(user_ids):
        recs = pinsage.recommend_user(user=user_id, n_rec=recall_item_num)
        item_list = list(recs.values())[0].tolist()
        if not item_list:
            user_recall_items_dict[user_id] = []
            continue
        # One batched predict call (user/item accept array-likes) instead of
        # one model call per recalled item.
        scores = pinsage.predict(user=[user_id] * len(item_list), item=item_list)
        user_recall_items_dict[user_id] = list(zip(item_list, scores))
    return user_recall_items_dict


def pinsage_recall_train():
    """Train PinSage (u2i), save it, then build and evaluate per-user recall lists.

    Steps:
      1. Load selection records and inner-join user and item side features.
      2. Split the data into train and eval frames (see ``_split_hist_and_eval``).
      3. Fit PinSage. Save ``data_info`` and the inference-only model under
         ``subject_pinsage_model_path``.
      4. For every training user, recall the top ``recall_item_num`` items with
         their scores. Pickle ``{user: [(item, score), ...]}`` to
         ``subject_pinsage_recall_dict``.
      5. Report recall metrics against the eval frame.
    """
    # Number of items to recall per user.
    recall_item_num = 100

    logger.info("加载物品行为数据")
    all_select_df = get_all_select_df(offline=offline_mode)

    logger.info("获取物品信息数据")
    item_info = get_item_info_df()

    logger.info("获取用户信息数据")
    users_info = get_user_info_df()

    # Default inner joins: selections without a matching user/item row are dropped.
    all_select_df = all_select_df.merge(users_info, on='user_id')
    all_select_df = all_select_df.merge(item_info, on='subject_id')

    train_hist_select_df, train_last_select_df = _split_hist_and_eval(all_select_df)

    # Debug tip: subsample for a quick run (delete the temporary model files afterwards):
    # train_hist_select_df = train_hist_select_df.sample(frac=0.001)

    user_col = ['gender', 'identity', 'edu_background', 'logins', 'grade', 'experience']
    item_col = ['visits', 'stages_count', 'stage_shixuns_count', 'shixuns_count', 'study_count',
                'course_study_count', 'passed_count', 'challenge_count', 'evaluate_count',
                'study_pdf_attachment_count', 'averge_star']
    # All side features are registered as dense columns.
    # NOTE(review): gender / identity / edu_background look categorical and may
    # belong in sparse_col instead -- TODO confirm their dtype.
    dense_col = user_col + item_col
    model_cols = ['user', 'item'] + dense_col + ['label']

    # Feature lists are passed by keyword. The positional order of
    # build_trainset is (train_data, user_col, item_col, sparse_col, dense_col, ...),
    # so passing dense_col 4th positionally registered it as sparse_col.
    train_data, data_info = DatasetFeat.build_trainset(
        train_hist_select_df[model_cols],
        user_col=user_col,
        item_col=item_col,
        dense_col=dense_col,
        seed=2023,
        shuffle=False,
    )
    # NOTE(review): eval_data is not passed to fit(). Recall quality is
    # measured instead by metrics_pinsage_recall at the end.
    eval_data = DatasetFeat.build_testset(train_last_select_df[model_cols])

    reset_state("PinSage")

    pinsage = PinSage(
        "ranking",
        data_info,
        loss_type="cross_entropy",
        paradigm="u2i",
        embed_size=32,
        n_epochs=10,
        lr=3e-4,
        lr_decay=False,
        reg=None,
        batch_size=256,
        num_neg=1,
        dropout_rate=0.01,
        remove_edges=False,
        num_layers=1,
        num_neighbors=10,
        num_walks=10,
        neighbor_walk_len=2,
        sample_walk_len=5,
        termination_prob=0.5,
        margin=1.0,
        sampler="random",
        start_node="random",
        focus_start=False,
        seed=2023,
    )
    pinsage.fit(
        train_data,
        neg_sampling=True,
        verbose=2,
        shuffle=True,
    )

    # Save data_info into the model folder.
    data_info.save(path=subject_pinsage_model_path, model_name="pinsage_model")

    # Save the model:
    # - manual=True saves it with numpy (manual=False would use tf.train.Saver);
    # - inference_only=True keeps only the variables needed for predict/recommend.
    pinsage.save(
        path=subject_pinsage_model_path, model_name="pinsage_model", manual=True, inference_only=True
    )

    print("训练结束,模型已保存")

    logger.info('生成item所有用户的召回列表有得分')
    user_recall_items_dict = _recall_with_scores(
        pinsage, train_hist_select_df['user'].unique(), recall_item_num
    )

    logger.info('保存pinsage召回结果')
    # Context manager guarantees the file is flushed and closed.
    with open(subject_pinsage_recall_dict, 'wb') as f:
        pickle.dump(user_recall_items_dict, f)

    logger.info('pinsage召回效果评估')
    metrics_pinsage_recall(user_recall_items_dict, train_last_select_df, topk=recall_item_num)
|
|
|
|
|
|
def _main():
    """Script entry point: print a start message, then run PinSage recall training."""
    print("召回开始训练")
    pinsage_recall_train()


if __name__ == '__main__':
    _main()
|
|
|