You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

158 lines
5.8 KiB

import os
import sys
sys.path.append(os.getcwd())
from tqdm import tqdm
import warnings
import tensorflow as tf
from tensorflow.python.keras import backend as K
from tensorflow.python.keras.models import Model
import collections
import pickle
from libreco.algorithms import PinSage
from libreco.data import DatasetFeat
from config import logger, offline_mode
from config import need_metric_recall
from config import shixun_pinsage_model_path
from matching.shixun.recall_comm import get_all_select_df
from matching.shixun.recall_comm import get_user_info_df,get_item_info_df
from matching.shixun.recall_comm import metrics_pinsage_recall,get_hist_and_last_select,metrics_recall
from config import shixun_pinsage_recall_dict
# Put Keras into the training learning phase (1) globally for this process.
K.set_learning_phase(True)
# Graph (non-eager) mode on TF 2.x; this script also calls
# tf.compat.v1.reset_default_graph(), which presumes TF1-style graphs.
# Compare the major version numerically: plain string comparison of version
# strings is lexicographic and only happens to work for current releases.
if int(tf.__version__.split('.')[0]) >= 2:
    tf.compat.v1.disable_eager_execution()
# Register tqdm progress-bar helpers (e.g. `progress_apply`) on pandas objects.
tqdm.pandas()
# Suppress every warning for the rest of the process.
warnings.filterwarnings('ignore')
def reset_state(name):
    """Clear the TensorFlow default graph and print a banner for the next stage.

    Args:
        name: Label printed between two rulers of ``=`` characters.
    """
    banner_rule = "=" * 30
    tf.compat.v1.reset_default_graph()
    print("\n", banner_rule, name, banner_rule)
def _load_select_df():
    """Load user-item selection records joined with user and item side features."""
    logger.info("加载物品行为数据")
    all_select_df = get_all_select_df(offline=offline_mode)
    logger.info("获取物品信息数据")
    item_info = get_item_info_df()
    logger.info("获取用户信息数据")
    users_info = get_user_info_df()
    # pandas merge defaults to an inner join: records whose user or item has no
    # info row are dropped.
    all_select_df = all_select_df.merge(users_info, on='user_id')
    return all_select_df.merge(item_info, on='shixun_id')


def _to_training_frame(select_df):
    """Return a copy with ids renamed to ``user``/``item`` and ``label`` set to 1."""
    # rename() without inplace returns a new frame, so the caller's frame is
    # never mutated (the original code renamed shared frames in place).
    return select_df.rename(columns={'user_id': 'user', 'shixun_id': 'item'}).assign(label=1)


def _build_pinsage(data_info):
    """Construct the PinSage ranking model with this job's fixed hyperparameters."""
    return PinSage(
        "ranking",
        data_info,
        loss_type="cross_entropy",
        paradigm="u2i",
        embed_size=32,
        n_epochs=3,
        lr=3e-4,
        lr_decay=False,
        reg=None,
        batch_size=256,
        num_neg=1,
        dropout_rate=0.01,
        remove_edges=False,
        num_layers=1,
        num_neighbors=10,
        num_walks=10,
        neighbor_walk_len=2,
        sample_walk_len=5,
        termination_prob=0.5,
        margin=1.0,
        sampler="random",
        start_node="random",
        focus_start=False,
        seed=2023,
    )


def _recall_for_users(model, users, n_rec):
    """Return ``{user: [(item, score), ...]}`` holding up to ``n_rec`` items per user."""
    # Kept as defaultdict(dict) so the pickled object's type is unchanged for
    # whoever loads it; the stored values are lists of (item, score) tuples.
    user_recall_items_dict = collections.defaultdict(dict)
    for user_id in tqdm(users):
        # recommend_user returns a mapping; take its first (presumably only) value.
        item_list = list(model.recommend_user(user=user_id, n_rec=n_rec).values())[0].tolist()
        if not item_list:
            user_recall_items_dict[user_id] = []
            continue
        # Score every recalled item in one batched call instead of one
        # predict() call per (user, item) pair.
        scores = model.predict(user=[user_id] * len(item_list), item=item_list)
        user_recall_items_dict[user_id] = list(zip(item_list, scores))
    return user_recall_items_dict


def pinsage_recall_train():
    """
    PinSage recall training.

    1. Load selection records and join user/item side features.
    2. If ``need_metric_recall`` is set, hold out each user's last selection
       (via ``get_hist_and_last_select``) for evaluation; otherwise train on
       all records.
    3. Train PinSage, then save ``data_info`` and the model under
       ``shixun_pinsage_model_path``.
    4. Recommend ``recall_item_num`` items for every training user, score
       them, and pickle ``{user: [(item, score), ...]}`` to
       ``shixun_pinsage_recall_dict``.
    5. If ``need_metric_recall`` is set, evaluate the recall lists against
       the held-out selections.
    """
    # Number of items to recall per user
    recall_item_num = 100

    all_select_df = _load_select_df()

    # For recall evaluation, hold out each user's last selection; when no
    # evaluation is needed, train on the full data set and skip evaluation.
    if need_metric_recall:
        logger.info('获取物品行为数据历史和最后一次选择')
        hist_df, last_df = get_hist_and_last_select(all_select_df)
        train_last_select_df = _to_training_frame(last_df)
    else:
        hist_df, train_last_select_df = all_select_df, None
    train_hist_select_df = _to_training_frame(hist_df)

    # Feature columns: user features first, then item features.
    user_col = ['gender', 'identity', 'edu_background', 'logins', 'grade', 'experience']
    item_col = ['visits', 'challenges_count', 'averge_star', 'task_pass']
    dense_col = user_col + item_col
    train_data = train_hist_select_df[['user', 'item'] + dense_col + ['label']]
    # NOTE(review): libreco documents build_trainset(train_data, user_col,
    # item_col, sparse_col, dense_col, ...), so this 4th positional argument
    # binds `dense_col` to `sparse_col`, and these features are presumably
    # treated as categorical. Confirm whether `dense_col=dense_col` was
    # intended before changing it (string columns such as gender would then
    # need numeric encoding).
    train_data, data_info = DatasetFeat.build_trainset(
        train_data, user_col, item_col, dense_col, seed=2023, shuffle=False
    )

    reset_state("PinSage")
    pinsage = _build_pinsage(data_info)
    pinsage.fit(
        train_data,
        neg_sampling=True,
        verbose=2,
        shuffle=True
    )

    # Save data_info into the model folder.
    data_info.save(path=shixun_pinsage_model_path, model_name="pinsage_model")
    # manual=True saves the model with numpy (manual=False would use
    # tf.train.Saver); inference_only=True keeps only the variables needed
    # for prediction and recommendation.
    pinsage.save(
        path=shixun_pinsage_model_path, model_name="pinsage_model", manual=True, inference_only=True
    )
    logger.info("训练结束,模型已保存")

    logger.info('生成item所有用户的召回列表有得分')
    user_recall_items_dict = _recall_for_users(
        pinsage, train_hist_select_df['user'].unique(), recall_item_num
    )

    logger.info('保存pinsage召回结果')
    # Context manager guarantees the pickle file is flushed and closed.
    with open(shixun_pinsage_recall_dict, 'wb') as f:
        pickle.dump(user_recall_items_dict, f)

    # Evaluate only against genuinely held-out selections. Previously a random
    # sample of the training data was "evaluated" here, which is meaningless.
    if train_last_select_df is not None:
        logger.info('pinsage召回效果评估')
        metrics_pinsage_recall(user_recall_items_dict, train_last_select_df, topk=recall_item_num)
def _main():
    """Script entry point: announce the start, then run PinSage recall training."""
    print("召回开始训练")
    pinsage_recall_train()


if __name__ == '__main__':
    _main()