You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

47 lines
1.6 KiB

5 months ago
import os
import sys
sys.path.append(os.getcwd())
import tensorflow as tf
from libreco.algorithms import PinSage
from libreco.data import DataInfo
import collections
from datetime import datetime
from config import subject_pinsage_model_path
from config import logger,test_user_id
def load_mmodel(model_path, model_name):
    """Load a saved PinSage model together with its DataInfo.

    The TensorFlow (v1 compat) default graph is reset before loading,
    presumably so repeated loads in one process start from a clean graph.

    Args:
        model_path: directory the model and its DataInfo were saved to.
        model_name: name the model was saved under.

    Returns:
        The PinSage instance returned by ``PinSage.load``.
    """
    tf.compat.v1.reset_default_graph()
    # DataInfo and the model are looked up with the same path/name pair.
    location = dict(path=model_path, model_name=model_name)
    info = DataInfo.load(**location)
    return PinSage.load(**location, data_info=info, manual=True)
def pinsage_recall(user_id, topk=20):
    """Recall the top-``topk`` items for ``user_id`` with the pretrained PinSage model.

    The model is loaded from ``subject_pinsage_model_path`` on every call.
    Loading is excluded from the logged elapsed time.

    Args:
        user_id: ID of the user to recommend for.
        topk: number of items to recall (passed as ``n_rec``).

    Returns:
        A list of ``(item, score)`` tuples in the order produced by
        ``recommend_user``. Each score is the first element of
        ``PinSage.predict`` for that user/item pair.
    """
    # Load the pretrained PinSage model.
    pinsage = load_mmodel(model_path=subject_pinsage_model_path, model_name="pinsage_model")
    start_time = datetime.now()
    logger.info(f"本次需要进行pinsage召回的用户ID: {user_id}")
    # recommend_user returns a mapping (presumably {user: item_array}); take the
    # first value and convert it to a plain list.
    item_list = list(pinsage.recommend_user(user=user_id, n_rec=topk).values())[0].tolist()
    scores = [pinsage.predict(user=user_id, item=item)[0] for item in item_list]
    recall_items = list(zip(item_list, scores))
    # Elapsed time in milliseconds. Both timestamps must come from the same
    # clock, and total_seconds() is used because timedelta.microseconds is
    # only the sub-second component (whole seconds would be dropped).
    end_time = datetime.now()
    cost_time_millisecond = round((end_time - start_time).total_seconds() * 1000.0, 3)
    logger.info(f"本次推荐总耗时: {cost_time_millisecond} 毫秒")
    return recall_items
if __name__ == '__main__':
    # Smoke run: recall 20 items for the configured test user and print the
    # (item, score) pairs.
    print(pinsage_recall(user_id=test_user_id, topk=20))