You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

207 lines
8.0 KiB

5 months ago
import pandas as pd
from tqdm import tqdm
import warnings
import pickle
import json
import random
import time
import os
import sys
sys.path.append(os.getcwd())
from config import logger
from config import test_user_id
from config import shixun_cold_start_recall_dict
from config import cold_start_shixuns_data_path
from config import shixuns_data_path
from config import shixun_cold_start_user_shixun_dict
from config import cold_start_shixuns_parent_path
from matching.shixun.recall_comm import get_all_select_df
from matching.shixun.recall_comm import get_item_info_df
from utils import is_number
# Cold-start recall recommendations for users.
# Register tqdm's pandas integration so that the progress_apply /
# progress_aggregate calls used below are available.
tqdm.pandas()
# Suppress all warnings for the rest of the process.
warnings.filterwarnings(action='ignore')
def make_disciplines_list(data):
    """Parse a raw discipline-id value into a de-duplicated list of integers.

    ``data`` is converted with ``str()``, every ``[SEP]`` marker is treated as
    a comma, and the text is split on commas.  Each token has quotes, braces
    and spaces removed.  Tokens containing ``nan`` and tokens rejected by
    ``is_number`` are dropped; the remaining ones are rounded to integers and
    kept in first-seen order without duplicates.

    :param data: any value; it is stringified before parsing
    :return: list of unique integer discipline ids
    """
    unique_ids = []
    tokens = str(data).replace('[SEP]', ',').split(',')
    for token in tokens:
        if 'nan' in token:
            continue
        cleaned = token
        # Strip the decoration left over from stringified sets/lists.
        for junk in ("'", '{', '}', ' '):
            cleaned = cleaned.replace(junk, '')
        if not is_number(cleaned):
            continue
        number = round(float(cleaned))
        if number not in unique_ids:
            unique_ids.append(number)
    return unique_ids
def build_user_sel_discipline_dict():
    """Build, persist and return the user -> selected-discipline-ids mapping.

    The selection records (``get_all_select_df(offline=False)``) are merged
    with the item information (``get_item_info_df()``) on ``shixun_id``, then
    grouped by ``user_id`` with every column aggregated into a set.  Each
    user's ``disciplines_id`` set is converted to a list of integer ids by
    ``make_disciplines_list``.  The resulting dictionary is pickled to the
    path ``shixun_cold_start_user_shixun_dict`` before being returned.

    :return: dict mapping each user_id to its list of discipline ids
    """
    logger.info("加载选课行为数据")
    all_select_df = get_all_select_df(offline=False)

    logger.info("获取实践项目信息数据")
    item_info_df = get_item_info_df()

    logger.info('生成用户选课记录列表')
    all_select_df = all_select_df.merge(item_info_df, on='shixun_id')
    all_select_df = all_select_df.groupby('user_id').progress_aggregate(set).reset_index()

    disciplines_lists = all_select_df['disciplines_id'].progress_apply(make_disciplines_list)
    user_sel_discipline_dict = dict(zip(all_select_df['user_id'], disciplines_lists))

    # Use a context manager so the cache file is flushed and closed
    # deterministically instead of relying on garbage collection.
    with open(shixun_cold_start_user_shixun_dict, 'wb') as dict_file:
        pickle.dump(user_sel_discipline_dict, dict_file)
    return user_sel_discipline_dict
def build_cold_start_recall_dict(topk=100):
    """Build, persist and return the per-discipline cold-start recall lists.

    The module-level ``df_cold_start_df`` is grouped by ``disciplines_id``.
    For every group whose key passes ``is_number``:

    * the "hottest" items are ranked by ``myshixuns_count`` descending, then
      ``trainee`` (difficulty, per the original comments) ascending;
    * the "newest" items are ranked by creation time descending, then
      ``trainee`` ascending;
    * the first ``topk // 2`` of each ranking are combined, de-duplicated on
      ``shixun_id`` and shuffled;
    * the result is written to ``<cold_start_shixuns_parent_path><id>.csv``
      (tab separated) and stored in the returned dictionary.

    The dictionary is also pickled to ``shixun_cold_start_recall_dict``.

    :param topk: upper bound on the number of items kept per discipline
                 (half hottest, half newest, before de-duplication)
    :return: dict mapping int discipline id to a list of
             ``(shixun_id, shixun_name)`` tuples
    """
    default_created_at = '2016-01-01 00:00:00'
    time_format = '%Y-%m-%d %H:%M:%S'
    default_ts = time.mktime(time.strptime(default_created_at, time_format))

    def _to_timestamp(value):
        """Convert a creation-time string to a timestamp, tolerating bad values."""
        try:
            return time.mktime(time.strptime(str(value), time_format))
        except (ValueError, OverflowError):
            # Unparseable values (for example a '-1' placeholder left by an
            # earlier fillna) are treated like a missing creation time.
            return default_ts

    cold_start_user_items_dict = {}
    logger.info("开始生成实践项目冷启动召回字典")

    # Group the items by interest-tag (discipline) id.
    for group_id, df_group in tqdm(df_cold_start_df.groupby('disciplines_id')):
        # Only numeric discipline ids can be used as dictionary keys below.
        if not is_number(group_id):
            continue
        group_key = int(group_id)
        recall_items = cold_start_user_items_dict.setdefault(group_key, [])
        id_file = cold_start_shixuns_parent_path + str(group_key) + '.csv'

        # Work on a copy: the groupby slice must not be mutated in place.
        df_group = df_group.copy()
        # Timestamp column is only used as a sort key.
        df_group['created_at'] = df_group['created_at'].fillna(default_created_at)
        df_group['created_at_ts'] = df_group['created_at'].progress_apply(_to_timestamp)

        # Hottest: selection count descending, difficulty ascending.
        df_group_hottest = df_group.sort_values(
            by=['myshixuns_count', 'trainee'], ascending=[False, True]
        ).drop_duplicates(['shixun_id'])

        # Newest: creation time descending, difficulty ascending.
        df_group_newest = df_group.sort_values(
            by=['created_at_ts', 'trainee'], ascending=[False, True]
        ).drop_duplicates(['shixun_id'])

        # Take half from each ranking, then drop items present in both.
        half = topk // 2
        df_recall = pd.concat([df_group_hottest[:half], df_group_newest[:half]])
        df_recall = df_recall.drop_duplicates(['shixun_id'])

        # Shuffle; the result must be assigned because sample() returns a new frame.
        df_recall = df_recall.sample(frac=1).reset_index(drop=True)

        # Record the recall list for this discipline.
        recall_items.extend(zip(df_recall['shixun_id'], df_recall['shixun_name']))

        # Persist the items recalled for this discipline.
        df_recall.to_csv(id_file, columns=['shixun_id', 'shixun_name', 'visits',
                                           'myshixuns_count', 'trainee', 'averge_star',
                                           'created_at', 'updated_at'],
                         sep='\t', index=False)

    # Persist the complete recall dictionary; the context manager closes the file.
    with open(shixun_cold_start_recall_dict, 'wb') as dict_file:
        pickle.dump(cold_start_user_items_dict, dict_file)
    return cold_start_user_items_dict
# Load the cold-start item table; missing values in every column become the
# string '-1'.
logger.info('加载冷启动召回数据')
df_cold_start_df = pd.read_csv(cold_start_shixuns_data_path, sep='\t', encoding='utf-8')
df_cold_start_df.fillna('-1', inplace=True)

# NOTE: the pickle files below are caches written by this module; pickle must
# only ever be loaded from trusted locations.

# User -> selected-discipline-ids mapping: reuse the cached file when present,
# otherwise build (and cache) it.
if os.path.exists(shixun_cold_start_user_shixun_dict):
    logger.info('加载用户选择课程所属学科字典')
    with open(shixun_cold_start_user_shixun_dict, 'rb') as _dict_file:
        cold_start_user_subject_dict = pickle.load(_dict_file)
else:
    logger.info('生成用户选择课程所属学科字典')
    cold_start_user_subject_dict = build_user_sel_discipline_dict()

# Discipline -> recall-items mapping: reuse the cached file when present,
# otherwise build (and cache) it.
if os.path.exists(shixun_cold_start_recall_dict):
    logger.info('加载用户冷启动召回字典')
    with open(shixun_cold_start_recall_dict, 'rb') as _dict_file:
        cold_start_items_dict = pickle.load(_dict_file)
else:
    logger.info('生成用户冷启动召回字典')
    cold_start_items_dict = build_cold_start_recall_dict(topk=200)
def cold_start_user_recall(user_id, disciplines_id_list, topk=100):
    """Return cold-start recommendations for a user.

    The candidate disciplines are chosen as follows:

    * interest tags were passed: the tags plus the disciplines of the items
      the user already selected (if any);
    * no tags, but the user has a selection history: the disciplines from
      that history;
    * neither: every discipline present in ``df_cold_start_df``.

    The pre-built recall lists of all candidate disciplines are merged,
    shuffled, and the first ``topk`` distinct items are returned.

    :param user_id: id used to look up the user's selection history in
                    ``cold_start_user_subject_dict``
    :param disciplines_id_list: iterable of interest-tag (discipline) ids,
                                or ``None``/empty when the user has none
    :param topk: maximum number of items to return; values <= 0 give ``{}``
    :return: dict mapping shixun_id to shixun_name
    """
    # Disciplines of the items the user has already selected, if any.
    history_ids = list(cold_start_user_subject_dict.get(user_id, ()))
    # Accept any iterable (list, tuple, set, ...) for the requested tags.
    requested_ids = [] if disciplines_id_list is None else list(disciplines_id_list)

    if requested_ids:
        candidate_ids = requested_ids + history_ids
    elif history_ids:
        candidate_ids = history_ids
    else:
        candidate_ids = df_cold_start_df['disciplines_id'].unique().tolist()

    # Normalise to int before de-duplicating so that 4 and '4' are the same
    # discipline; ids that cannot be converted are ignored.
    normalized_ids = set()
    for raw_id in candidate_ids:
        try:
            normalized_ids.add(int(raw_id))
        except (TypeError, ValueError, OverflowError):
            continue

    # Flatten the recall lists of all candidate disciplines.
    rank_list = [
        item
        for discipline_id in normalized_ids
        for item in cold_start_items_dict.get(discipline_id, ())
    ]

    # Mix the items coming from different disciplines.
    random.shuffle(rank_list)

    # De-duplicate while filling, so that an item listed under several
    # disciplines does not reduce the number of results below topk.
    recommend_results = {}
    for shixun_id, shixun_name in rank_list:
        if len(recommend_results) >= topk:
            break
        recommend_results.setdefault(shixun_id, shixun_name)
    return recommend_results
if __name__ == '__main__':
    # Manual check of the cold-start recommendation for the configured test user.
    recommend_results = cold_start_user_recall(test_user_id, [4], 100)
    # indent=4 pretty-prints with a 4-space indent; ensure_ascii=False keeps
    # non-ASCII names readable.
    print(json.dumps(recommend_results, indent=4, ensure_ascii=False))