You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

54 lines
1.9 KiB

5 months ago
import os
import sys
sys.path.append(os.getcwd())
import tensorflow as tf
from libreco.algorithms import PinSage
from libreco.data import DataInfo
from libreco.evaluation import evaluate
import collections
from datetime import datetime
from collections import Counter
from config import shixun_pinsage_model_path
from config import logger,test_user_id
from config import shixun_pinsage_model_path
from matching.shixun.recall_comm import get_all_select_df
from matching.shixun.recall_comm import get_user_info_df,get_item_info_df
from matching.shixun.recall_comm import metrics_pinsage_recall,get_hist_and_last_select,metrics_recall
def load_mmodel(model_path, model_name):
    """Load a saved PinSage model together with its DataInfo.

    Args:
        model_path: directory the model and its DataInfo were saved to.
        model_name: name under which both artifacts were saved.

    Returns:
        The PinSage instance returned by ``PinSage.load``.
    """
    # Start from a fresh TF1-style default graph before restoring the model.
    # NOTE(review): presumably this avoids clashes with graphs from earlier loads
    # in the same process; confirm.
    tf.compat.v1.reset_default_graph()
    info = DataInfo.load(path=model_path, model_name=model_name)
    return PinSage.load(
        path=model_path,
        model_name=model_name,
        data_info=info,
        manual=True,
    )
def pinsage_recall(user_id, topk=20):
    """Recall the top-k items for one user with the pretrained PinSage model.

    Args:
        user_id: ID of the user to recommend for. It is passed unchanged to
            ``recommend_user`` and ``predict``.
        topk: number of items requested from ``recommend_user``.

    Returns:
        A list of ``(item, score)`` tuples in the order produced by
        ``recommend_user``. Each score is the first element returned by
        ``predict`` for that (user, item) pair.
    """
    # Load the pretrained PinSage model. It is reloaded from disk on every call.
    pinsage = load_mmodel(model_path=shixun_pinsage_model_path, model_name="pinsage_model")

    # The timer starts after model loading, so it measures recommendation
    # plus per-item scoring only.
    start_time = datetime.now()
    logger.info(f"本次需要进行pinsage召回的用户ID: {user_id}")

    # recommend_user returns a mapping. Take the item array of its first entry.
    # NOTE(review): presumably keyed by user_id with a single entry; confirm.
    recommendations = pinsage.recommend_user(user=user_id, n_rec=topk)
    item_list = list(recommendations.values())[0].tolist()

    scores = [pinsage.predict(user=user_id, item=item)[0] for item in item_list]
    recall_items = list(zip(item_list, scores))

    # Compute the elapsed time in milliseconds. Both timestamps must come from
    # the same clock: mixing now() with utcnow() skews the result by the UTC offset.
    # Use total_seconds() because timedelta.microseconds is only the sub-second
    # component and silently drops whole seconds.
    end_time = datetime.now()
    cost_time_millisecond = round((end_time - start_time).total_seconds() * 1000.0, 3)
    logger.info(f"本次推荐总耗时: {cost_time_millisecond} 毫秒")
    return recall_items
if __name__ == '__main__':
    # Manual smoke run: print the top-100 PinSage recall for the configured test user.
    print(pinsage_recall(user_id=test_user_id, topk=100))