You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
54 lines
1.9 KiB
54 lines
1.9 KiB
import os
|
|
import sys
|
|
sys.path.append(os.getcwd())
|
|
import tensorflow as tf
|
|
from libreco.algorithms import PinSage
|
|
from libreco.data import DataInfo
|
|
from libreco.evaluation import evaluate
|
|
import collections
|
|
from datetime import datetime
|
|
from collections import Counter
|
|
from config import shixun_pinsage_model_path
|
|
from config import logger,test_user_id
|
|
from config import shixun_pinsage_model_path
|
|
from matching.shixun.recall_comm import get_all_select_df
|
|
from matching.shixun.recall_comm import get_user_info_df,get_item_info_df
|
|
from matching.shixun.recall_comm import metrics_pinsage_recall,get_hist_and_last_select,metrics_recall
|
|
|
|
def load_mmodel(model_path, model_name):
    """Restore a saved PinSage model together with its DataInfo.

    The TensorFlow v1-compat default graph is reset first. The DataInfo saved
    under ``model_name`` at ``model_path`` is then loaded and handed to
    ``PinSage.load`` with the same path/name.

    Args:
        model_path: Location passed as ``path`` to ``DataInfo.load`` and
            ``PinSage.load``.
        model_name: Name passed as ``model_name`` to both loaders.

    Returns:
        The object returned by ``PinSage.load``.
    """
    # Start from an empty default graph before restoring; presumably this
    # avoids clashes with graphs built by earlier loads in the same process.
    tf.compat.v1.reset_default_graph()

    # Both loaders take the same location/name arguments.
    load_kwargs = {"path": model_path, "model_name": model_name}
    info = DataInfo.load(**load_kwargs)
    return PinSage.load(data_info=info, manual=True, **load_kwargs)
|
|
|
|
def pinsage_recall(user_id, topk=20):
    """Recommend the top-``topk`` items for the given user via PinSage.

    Loads the pre-trained PinSage model, requests ``topk`` recommendations for
    ``user_id`` and scores each recommended item with ``predict``. The elapsed
    time of the recommend/score phase (model loading excluded) is logged in
    milliseconds.

    Args:
        user_id: User to recommend for, in whatever form the model's
            ``recommend_user`` / ``predict`` accept.
        topk: Number of items requested from ``recommend_user``.

    Returns:
        A list of ``(item, score)`` tuples, in the order the items were
        returned by ``recommend_user``.
    """
    # Load the pre-trained PinSage model (done on every call).
    pinsage = load_mmodel(model_path=shixun_pinsage_model_path, model_name="pinsage_model")

    start_time = datetime.now()
    logger.info(f"本次需要进行pinsage召回的用户ID: {user_id}")

    # The result is read through .values(); only one user is requested, so
    # the first (only) entry holds that user's recommended items.
    recommendations = pinsage.recommend_user(user=user_id, n_rec=topk)
    item_list = list(recommendations.values())[0].tolist()

    # Score each recommended item individually; predict(...)[0] takes the
    # single score for this (user, item) pair.
    scores = [pinsage.predict(user=user_id, item=item)[0] for item in item_list]
    recall_items = list(zip(item_list, scores))

    # Elapsed time in milliseconds. Both timestamps must come from the same
    # clock, and total_seconds() includes whole seconds, whereas the
    # timedelta `.microseconds` attribute holds only the sub-second part.
    end_time = datetime.now()
    cost_time_millisecond = round((end_time - start_time).total_seconds() * 1000.0, 3)
    logger.info(f"本次推荐总耗时: {cost_time_millisecond} 毫秒")

    return recall_items
|
|
|
|
|
|
if __name__ == '__main__':
    # Manual smoke run: recall 100 items for the configured test user and
    # print the resulting (item, score) pairs.
    print(pinsage_recall(user_id=test_user_id, topk=100))
|