You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

47 lines
1.6 KiB

import os
import sys
sys.path.append(os.getcwd())
import tensorflow as tf
from libreco.algorithms import PinSage
from libreco.data import DataInfo
import collections
from datetime import datetime
from config import subject_pinsage_model_path
from config import logger,test_user_id
def load_mmodel(model_path, model_name):
    """Load a saved PinSage model together with its DataInfo.

    Args:
        model_path: Location passed as ``path`` to both ``DataInfo.load`` and
            ``PinSage.load``.
        model_name: Name passed as ``model_name`` to both loaders.

    Returns:
        The object returned by ``PinSage.load``.
    """
    # Clear the default TF1 graph before loading.
    # NOTE(review): presumably this keeps repeated loads in one process from
    # colliding with earlier graph state -- confirm.
    tf.compat.v1.reset_default_graph()

    # Both loaders take the same path/name pair, so build it once.
    location = {"path": model_path, "model_name": model_name}
    restored_info = DataInfo.load(**location)
    return PinSage.load(data_info=restored_info, manual=True, **location)
def pinsage_recall(user_id, topk=20):
    """Recommend the top-k items for one user with the saved PinSage model.

    Args:
        user_id: ID of the user to recall items for; passed unchanged to the
            model's ``recommend_user`` and ``predict`` methods.
        topk: Number of items requested from ``recommend_user`` (default 20).

    Returns:
        A list of ``(item, score)`` tuples, in the order the items come back
        from ``recommend_user``; ``score`` is the first element of what
        ``predict`` returns for that user/item pair.
    """
    # Load the pre-trained PinSage model.
    # NOTE(review): this reloads the model on every call; consider caching it
    # at the call site if this function is invoked repeatedly.
    pinsage = load_mmodel(
        model_path=subject_pinsage_model_path, model_name="pinsage_model"
    )

    # Timing starts after the model is loaded, so load time is excluded.
    start_time = datetime.now()
    logger.info(f"本次需要进行pinsage召回的用户ID: {user_id}")

    # recommend_user returns a mapping; take the first value and convert it
    # to a plain list.
    recommendations = pinsage.recommend_user(user=user_id, n_rec=topk)
    item_list = list(recommendations.values())[0].tolist()

    # Score each recommended item individually for this user.
    scores = [
        pinsage.predict(user=user_id, item=item)[0] for item in item_list
    ]
    recall_items = list(zip(item_list, scores))

    # Elapsed time in milliseconds. Both timestamps must come from the same
    # clock (the original mixed now() with utcnow(), adding the timezone
    # offset), and total_seconds() is required because
    # timedelta.microseconds is only the sub-second component.
    end_time = datetime.now()
    elapsed = end_time - start_time
    cost_time_millisecond = round(elapsed.total_seconds() * 1000.0, 3)
    logger.info(f"本次推荐总耗时: {cost_time_millisecond} 毫秒")

    return recall_items
if __name__ == "__main__":
    # Smoke run: recall the top 20 items for the configured test user and
    # print the resulting (item, score) pairs.
    print(pinsage_recall(user_id=test_user_id, topk=20))