You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

145 lines
5.1 KiB

import os
import sys
sys.path.append(os.getcwd())
import pickle
from tqdm import tqdm
import os
import warnings
import pandas as pd
import numpy as np
from config import shixun_features_save_path
from config import shixun_save_path
from config import shixuns_bert_emb_dict
from config import shixuns_bert_em_path
# Register tqdm's pandas integration so that `.progress_apply` (used in
# get_rank_item_info_dict) shows a progress bar.
tqdm.pandas()
# Suppress every warning for the rest of the process (this also hides numpy
# RuntimeWarnings such as division by zero).
warnings.filterwarnings('ignore')
def save_rank_results(recall_df, topk=5, model_name=None, save_path=None):
    """
    Save the top-``topk`` ranked items of every user to a tab-separated file.

    Rows are ranked per user by descending ``pred_score`` (ties broken by row
    order, ``method='first'``) and pivoted so that each output row reads
    ``user_id, item_1, ..., item_<topk>``. Users with fewer than ``topk``
    candidates get empty cells for the missing positions.

    Args:
        recall_df: DataFrame with ``user_id`` and ``pred_score`` columns plus
            the item column to export. The caller's frame is not modified.
        topk: Number of items kept per user.
        model_name: Output file name prefix; the file written is
            ``<save_path><model_name>_results.csv``. Must be a string (the
            default ``None`` fails with a TypeError on concatenation).
        save_path: Directory prefix, concatenated as-is (so it should end with
            a path separator). Defaults to ``shixun_features_save_path``.

    Returns:
        The path of the written file.

    Raises:
        AssertionError: if no user has any ranked item (e.g. empty input).
    """
    if save_path is None:
        save_path = shixun_features_save_path

    # sort_values returns a new frame, so the column edits below never touch
    # the caller's DataFrame.
    recall_df = recall_df.sort_values(by=['user_id', 'pred_score'])
    recall_df['rank'] = recall_df.groupby(['user_id'])['pred_score'].rank(ascending=False, method='first')

    # Sanity check: every user must have at least one ranked item. Users with
    # fewer than topk items are allowed; their trailing columns stay empty.
    max_rank_per_user = recall_df.groupby('user_id')['rank'].max()
    assert max_rank_per_user.min() >= 1

    del recall_df['pred_score']
    submit = recall_df[recall_df['rank'] <= topk].set_index(['user_id', 'rank']).unstack(-1).reset_index()

    # rank() produces floats (1.0, 2.0, ...); convert them to ints so they
    # match the integer keys of the rename mapping below.
    submit.columns = [int(col) if isinstance(col, float) else col for col in submit.columns.droplevel(0)]

    # Column names in the submission format: user_id, item_1 .. item_<topk>
    # (topk + 1 so that the last rank is renamed as well).
    item_columns = {'': 'user_id'}
    item_columns.update({rank: 'item_' + str(rank) for rank in range(1, topk + 1)})
    submit = submit.rename(columns=item_columns)

    save_name = save_path + model_name + '_results.csv'
    submit.to_csv(save_name, sep='\t', index=False, header=True)
    return save_name
def norm_sim(sim_df, weight=0.0):
    """
    Min-max normalize ranking scores to [0, 1], then add ``weight``.

    Args:
        sim_df: pandas Series of numeric scores. A DataFrame is not supported:
            its per-column min/max would make the equality test below ambiguous.
        weight: Constant offset added to every normalized score.

    Returns:
        A new float Series with the same index and name. When all scores are
        equal (max == min), every entry, including missing ones, becomes
        ``1.0 + weight``. Otherwise missing entries stay missing.
    """
    min_sim = sim_df.min()
    max_sim = sim_df.max()
    if max_sim == min_sim:
        # Constant input: avoid a 0/0 division and give every entry full score.
        normed = pd.Series(1.0, index=sim_df.index, name=sim_df.name)
    else:
        # Vectorized form of (sim - min) / (max - min) applied element-wise.
        normed = (sim_df - min_sim) / (max_sim - min_sim)
    return normed + weight
def get_kfold_users(train_df, n=5):
    """
    Split the distinct users of ``train_df`` into ``n`` folds for K-fold CV.

    The split is done per user, so each user id ends up in exactly one fold.
    Users are dealt out round-robin in order of first appearance: fold ``k``
    receives the users at positions ``k, k + n, k + 2n, ...``. This split is
    independent of the single train/validation split used elsewhere.

    Returns:
        A list with one slice of ``train_df['user_id'].unique()`` per fold.
    """
    distinct_users = train_df['user_id'].unique()
    return [distinct_users[slice(fold, None, n)] for fold in range(n)]
def fill_is_trainee_hab(x):
    """
    Flag whether the item's difficulty (``trainee``) is among those the user chose.

    Args:
        x: A row-like object exposing ``trainee_list`` (a set of strings) and
            ``trainee`` attributes.

    Returns:
        1 if ``str(float(x.trainee))`` (e.g. ``2`` -> ``'2.0'``) is in
        ``x.trainee_list``, otherwise 0. A ``trainee_list`` that is not a set
        (e.g. NaN for users without history) always yields 0.
    """
    chosen_levels = x.trainee_list
    if not isinstance(chosen_levels, set):
        return 0
    # The stored levels are strings of floats, so normalize the same way.
    return int(str(float(x.trainee)) in chosen_levels)
def make_shixun_tuple_func(df):
    """
    Collect seven feature columns of every row of ``df`` into tuples.

    Args:
        df: DataFrame containing the columns listed below (one
            ``groupby('shixun_id')`` group when called from
            ``get_rank_item_info_dict``).

    Returns:
        A list with one tuple per row, in row order:
        ``(created_at_ts, trainee, visits, myshixuns_count, challenges_count,
        averge_star, task_pass)``.
    """
    feature_cols = ('created_at_ts', 'trainee', 'visits', 'myshixuns_count',
                    'challenges_count', 'averge_star', 'task_pass')
    # Rows come from iterrows(), so each value is read from a per-row Series
    # (pandas may upcast mixed numeric columns to a common dtype there).
    return [tuple(row[col] for col in feature_cols) for _, row in df.iterrows()]
def get_rank_item_info_dict(item_info_df):
    """
    Build (or load from cache) the per-shixun item-feature dictionary.

    Side effect: the seven feature columns of ``item_info_df`` are filled
    (missing -> 0 / 0.0) and cast to int/float IN PLACE. This happens on every
    call, including when the cached dictionary is returned.

    If ``shixun_save_path + 'shixuns_info_rank_dict.pkl'`` exists, its contents
    are returned. Otherwise the dictionary is built and pickled to that path.
    It maps each ``shixun_id`` to the list of tuples produced by
    ``make_shixun_tuple_func`` for that shixun's rows:
    ``(created_at_ts, trainee, visits, myshixuns_count, challenges_count,
    averge_star, task_pass)``.
    """
    # Column -> (value used for missing entries, target type), applied in order.
    column_types = {
        'created_at_ts': (0.0, float),
        'trainee': (0, int),
        'visits': (0, int),
        'myshixuns_count': (0, int),
        'challenges_count': (0, int),
        'averge_star': (0.0, float),
        'task_pass': (0, int),
    }
    for col, (fill_value, cast_type) in column_types.items():
        item_info_df[col] = item_info_df[col].fillna(fill_value).astype(cast_type)

    cache_file = shixun_save_path + 'shixuns_info_rank_dict.pkl'
    if os.path.exists(cache_file):
        # NOTE: pickle.load can execute code from a crafted file; only point
        # this at caches written by this pipeline.
        with open(cache_file, 'rb') as f:
            return pickle.load(f)

    # groupby(...).progress_apply yields one list of tuples per shixun_id;
    # reset_index() exposes it as column 0 next to 'shixun_id'.
    item_info_tuples = item_info_df.groupby('shixun_id').progress_apply(make_shixun_tuple_func).reset_index()
    item_info_dict = dict(zip(item_info_tuples['shixun_id'], item_info_tuples[0]))
    with open(cache_file, 'wb') as f:
        pickle.dump(item_info_dict, f)
    return item_info_dict
def get_item_bert_emb_dict():
    """
    Load, or build and cache, the L2-normalized BERT embedding of every shixun.

    If the pickle at ``shixuns_bert_emb_dict`` exists, its contents are
    returned as-is. Otherwise the tab-separated file at
    ``shixuns_bert_em_path`` is read, and every column whose name contains
    ``'bert_em'`` is treated as one embedding dimension. Each row is scaled to
    unit L2 norm (all-zero rows are kept as zero vectors instead of becoming
    NaN). The resulting ``{shixun_id: 1-D numpy array}`` dict is pickled to
    ``shixuns_bert_emb_dict`` and returned.
    """
    # Load previously saved embeddings.
    if os.path.exists(shixuns_bert_emb_dict):
        # NOTE: pickle.load can execute code from a crafted file; only point
        # this at caches written by this pipeline.
        with open(shixuns_bert_emb_dict, 'rb') as f:
            return pickle.load(f)

    # Build the embeddings from the raw tab-separated export.
    item_emb_df = pd.read_csv(shixuns_bert_em_path, sep='\t', encoding='utf-8')
    item_emb_cols = [x for x in item_emb_df.columns if 'bert_em' in x]
    item_emb_np = np.ascontiguousarray(item_emb_df[item_emb_cols])

    # Row-wise L2 normalization. A zero norm would give 0/0 = NaN (the warning
    # is hidden by the module-wide ignore filter), so divide those rows by 1.
    norms = np.linalg.norm(item_emb_np, axis=1, keepdims=True)
    norms[norms == 0] = 1.0
    item_emb_np = item_emb_np / norms

    item_emb_dict = dict(zip(item_emb_df['shixun_id'], item_emb_np))
    with open(shixuns_bert_emb_dict, 'wb') as f:
        pickle.dump(item_emb_dict, f)
    return item_emb_dict