You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

158 lines
5.8 KiB

5 months ago
import os
import sys
sys.path.append(os.getcwd())
from tqdm import tqdm
import warnings
import tensorflow as tf
from tensorflow.python.keras import backend as K
from tensorflow.python.keras.models import Model
import collections
import pickle
from libreco.algorithms import PinSage
from libreco.data import DatasetFeat
from config import logger, offline_mode
from config import need_metric_recall
from config import shixun_pinsage_model_path
from matching.shixun.recall_comm import get_all_select_df
from matching.shixun.recall_comm import get_user_info_df,get_item_info_df
from matching.shixun.recall_comm import metrics_pinsage_recall,get_hist_and_last_select,metrics_recall
from config import shixun_pinsage_recall_dict
# Global framework setup executed at import time.
# Keras learning phase True = training mode.
# NOTE(review): set_learning_phase is deprecated in TF 2.x -- confirm it is still needed.
K.set_learning_phase(True)
# Compare the numeric major version: a plain string comparison
# (tf.__version__ >= '2.0.0') misorders multi-digit components, e.g. '10.0.0' < '2.0.0'.
if int(tf.__version__.split('.')[0]) >= 2:
    # Fall back to TF1-style graph mode.
    tf.compat.v1.disable_eager_execution()
# Register tqdm's pandas integration (progress_apply and friends).
tqdm.pandas()
# Silence all Python warnings for this process.
warnings.filterwarnings('ignore')
def reset_state(name):
    """Clear the TF1-style default graph and print a banner for the next stage.

    :param name: label printed between two rulers of ``=`` characters.
    """
    rule = "=" * 30
    tf.compat.v1.reset_default_graph()
    print("\n", rule, name, rule)
def _to_libreco_format(select_df):
    """Return a libreco-ready copy of a behaviour frame.

    Every row is an observed selection, so a constant positive ``label`` of 1
    is added, and ``user_id``/``shixun_id`` are renamed to the ``user``/``item``
    column names used for the libreco datasets.

    A new frame is returned rather than mutating ``select_df`` in place, so the
    caller's frame is never relabelled through an alias, and writing ``label``
    cannot hit a pandas "copy of a slice" frame (e.g. a ``sample()`` result).
    """
    formatted = select_df.rename(columns={'user_id': 'user', 'shixun_id': 'item'})
    formatted['label'] = 1
    return formatted


def _load_select_data(need_metric):
    """Load behaviour data joined with user/item features and split it.

    :param need_metric: when true, ``get_hist_and_last_select`` separates the
        history from the last selection (presumably per user -- TODO confirm)
        so recall can be evaluated on held-out data; otherwise the full data is
        used for training and a seeded 0.1% sample of it is a placeholder eval set.
    :return: ``(hist_df, last_df)``, both in libreco format.
    """
    logger.info("加载物品行为数据")
    all_select_df = get_all_select_df(offline=offline_mode)
    logger.info("获取物品信息数据")
    item_info = get_item_info_df()
    logger.info("获取用户信息数据")
    users_info = get_user_info_df()
    # Default (inner) merges: selections without a user/item feature row are dropped.
    all_select_df = all_select_df.merge(users_info, on='user_id')
    all_select_df = all_select_df.merge(item_info, on='shixun_id')

    if need_metric:
        logger.info('获取物品行为数据历史和最后一次选择')
        hist_df, last_df = get_hist_and_last_select(all_select_df)
    else:
        hist_df = all_select_df
        # Seeded so the placeholder eval set is reproducible (same seed as the model).
        last_df = all_select_df.sample(frac=0.001, random_state=2023)
    return _to_libreco_format(hist_df), _to_libreco_format(last_df)


def _recall_with_scores(model, user_ids, n_rec):
    """Recommend ``n_rec`` items per user and pair each item with its model score.

    :return: ``defaultdict`` mapping ``user_id -> [(item_id, score), ...]`` in
        recommendation order.
    """
    recall = collections.defaultdict(dict)
    for user_id in tqdm(user_ids):
        recs = model.recommend_user(user=user_id, n_rec=n_rec)
        # recommend_user is used here as returning a one-entry dict {user: item_array}.
        item_list = list(recs.values())[0].tolist()
        if not item_list:
            recall[user_id] = []
            continue
        # One batched predict call per user instead of one call per item.
        scores = model.predict(user=[user_id] * len(item_list), item=item_list)
        recall[user_id] = list(zip(item_list, scores))
    return recall


def pinsage_recall_train(recall_item_num=100):
    """PinSage recall training.

    Loads behaviour and feature data, builds libreco datasets, fits PinSage,
    saves ``data_info`` and the model under ``shixun_pinsage_model_path``,
    recommends ``recall_item_num`` scored items for every training user, pickles
    the result to ``shixun_pinsage_recall_dict`` and evaluates recall.

    :param recall_item_num: number of items recalled per user (default 100).
    """
    train_hist_select_df, train_last_select_df = _load_select_data(need_metric_recall)

    user_col = ['gender', 'identity', 'edu_background', 'logins', 'grade', 'experience']
    item_col = ['visits', 'challenges_count', 'averge_star', 'task_pass']
    feat_col = user_col + item_col
    model_cols = ['user', 'item'] + feat_col + ['label']

    # NOTE(review): libreco's build_trainset signature is
    # (train_data, user_col, item_col, sparse_col, dense_col, ...), so these
    # features are registered as *sparse* (categorical) columns. Numeric counts
    # such as logins/visits may belong in dense_col -- confirm before changing,
    # as it alters the model.
    train_data, data_info = DatasetFeat.build_trainset(
        train_hist_select_df[model_cols],
        user_col=user_col,
        item_col=item_col,
        sparse_col=feat_col,
        seed=2023,
        shuffle=False,
    )
    # NOTE(review): eval_data is not passed to fit(); confirm whether it should
    # be supplied via fit(eval_data=...) or dropped.
    eval_data = DatasetFeat.build_testset(train_last_select_df[model_cols])

    reset_state("PinSage")
    pinsage = PinSage(
        "ranking",
        data_info,
        loss_type="cross_entropy",
        paradigm="u2i",
        embed_size=32,
        n_epochs=3,
        lr=3e-4,
        lr_decay=False,
        reg=None,
        batch_size=256,
        num_neg=1,
        dropout_rate=0.01,
        remove_edges=False,
        num_layers=1,
        num_neighbors=10,
        num_walks=10,
        neighbor_walk_len=2,
        sample_walk_len=5,
        termination_prob=0.5,
        margin=1.0,
        sampler="random",
        start_node="random",
        focus_start=False,
        seed=2023,
    )
    pinsage.fit(train_data, neg_sampling=True, verbose=2, shuffle=True)

    # Save data_info into the configured model directory.
    data_info.save(path=shixun_pinsage_model_path, model_name="pinsage_model")
    # manual=True saves the model with numpy (manual=False would use tf.train.Saver);
    # inference_only=True keeps only the variables needed for predict/recommend.
    pinsage.save(
        path=shixun_pinsage_model_path,
        model_name="pinsage_model",
        manual=True,
        inference_only=True,
    )
    print("训练结束,模型已保存")

    logger.info('生成item所有用户的召回列表有得分')
    user_recall_items_dict = _recall_with_scores(
        pinsage, train_hist_select_df['user'].unique(), recall_item_num
    )

    logger.info('保存pinsage召回结果')
    with open(shixun_pinsage_recall_dict, 'wb') as f:
        pickle.dump(user_recall_items_dict, f)

    logger.info('pinsage召回效果评估')
    # NOTE(review): when need_metric_recall is false, train_last_select_df is a
    # sample of the training data itself, so these metrics are optimistic.
    metrics_pinsage_recall(user_recall_items_dict, train_last_select_df, topk=recall_item_num)
if __name__ == "__main__":
    # Script entry point: announce the run, then train PinSage and build recall lists.
    print("召回开始训练")
    pinsage_recall_train()