You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
47 lines
1.6 KiB
47 lines
1.6 KiB
5 months ago
|
import os
|
||
|
import sys
|
||
|
sys.path.append(os.getcwd())
|
||
|
import tensorflow as tf
|
||
|
from libreco.algorithms import PinSage
|
||
|
from libreco.data import DataInfo
|
||
|
import collections
|
||
|
from datetime import datetime
|
||
|
from config import subject_pinsage_model_path
|
||
|
from config import logger,test_user_id
|
||
|
|
||
|
def load_mmodel(model_path, model_name):
    """Load a saved PinSage model together with its DataInfo.

    The TensorFlow (v1-compat) default graph is reset before loading.
    NOTE(review): presumably this keeps repeated loads in one process from
    colliding in a shared graph — confirm.

    Args:
        model_path: directory passed as ``path`` to both ``DataInfo.load``
            and ``PinSage.load``.
        model_name: name passed as ``model_name`` to both loaders.

    Returns:
        The object returned by ``PinSage.load`` (loaded with ``manual=True``).
    """
    tf.compat.v1.reset_default_graph()

    # The model needs the DataInfo that was saved alongside it.
    info = DataInfo.load(path=model_path, model_name=model_name)

    return PinSage.load(
        path=model_path,
        model_name=model_name,
        data_info=info,
        manual=True,
    )
|
||
|
|
||
|
def pinsage_recall(user_id, topk=20):
    """Recall the top-k items for one user with a pretrained PinSage model.

    The model is (re)loaded from ``subject_pinsage_model_path`` on every call.

    Args:
        user_id: user identifier passed to ``recommend_user`` and ``predict``.
        topk: number of items requested from ``recommend_user`` (``n_rec``).

    Returns:
        A list of ``(item, score)`` tuples in the order produced by
        ``recommend_user``. Each score is the first element of
        ``pinsage.predict(user=user_id, item=item)``.
    """
    # Load the pretrained PinSage model.
    pinsage = load_mmodel(model_path=subject_pinsage_model_path, model_name="pinsage_model")

    start_time = datetime.now()
    logger.info(f"本次需要进行pinsage召回的用户ID: {user_id}")

    # recommend_user returns a mapping; take the item array of its first entry.
    # NOTE(review): assumes a single-user request yields exactly one entry — confirm.
    recommendations = pinsage.recommend_user(user=user_id, n_rec=topk)
    item_list = list(recommendations.values())[0].tolist()

    # Score each recalled item individually.
    recall_items = [
        (item, pinsage.predict(user=user_id, item=item)[0]) for item in item_list
    ]

    # Elapsed time in milliseconds. Both timestamps come from the same clock
    # (the original mixed now() and utcnow()). total_seconds() is used because
    # timedelta.microseconds holds only the sub-second part and drops whole seconds.
    elapsed = datetime.now() - start_time
    cost_time_millisecond = round(elapsed.total_seconds() * 1000.0, 3)
    logger.info(f"本次推荐总耗时: {cost_time_millisecond} 毫秒")

    return recall_items
|
||
|
|
||
|
if __name__ == "__main__":
    # Smoke run: recall 20 items for the configured test user and print them.
    print(pinsage_recall(user_id=test_user_id, topk=20))
|