You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

145 lines
5.1 KiB

5 months ago
import os
import sys
# Put the current working directory on the import path; presumably this is what
# lets the `config` module below resolve when run from the project root.
sys.path.append(os.getcwd())
import pickle
from tqdm import tqdm
# NOTE(review): os is already imported above; this repeated import is redundant.
import os
import warnings
import pandas as pd
import numpy as np
from config import shixun_features_save_path
from config import shixun_save_path
from config import shixuns_bert_emb_dict
from config import shixuns_bert_em_path
# Register tqdm with pandas so that the `progress_apply` call used in
# get_rank_item_info_dict is available.
tqdm.pandas()
# Suppress every warning for the remainder of the process.
warnings.filterwarnings('ignore')
def save_rank_results(recall_df, topk=5, model_name=None):
    """Write each user's top-k ranked items to a tab-separated results file.

    Rows are ranked per user by descending ``pred_score`` (ties broken by
    order of appearance after sorting), the best ``topk`` rows per user are
    pivoted into one row per user, and the result is written to
    ``shixun_features_save_path + model_name + '_results.csv'``.

    Args:
        recall_df: DataFrame with ``user_id``, ``pred_score`` and the item
            column(s) to pivot. The caller's frame is not modified because
            ``sort_values`` returns a new frame.
        topk: Number of ranked items to keep per user.
        model_name: Prefix of the output file name; must be a string.

    Raises:
        AssertionError: If the smallest per-user maximum rank is not >= 1
            (for example when ``recall_df`` is empty).
    """
    recall_df = recall_df.sort_values(by=['user_id', 'pred_score'])
    recall_df['rank'] = recall_df.groupby(['user_id'])['pred_score'].rank(ascending=False, method='first')

    # Sanity check: every user must have at least one ranked candidate.
    # (Users with fewer than topk candidates are allowed; their missing
    # ranks simply become empty cells in the output.)
    max_rank_per_user = recall_df.groupby('user_id')['rank'].max()
    assert max_rank_per_user.min() >= 1

    del recall_df['pred_score']
    top_rows = recall_df[recall_df['rank'] <= topk]
    submit = top_rows.set_index(['user_id', 'rank']).unstack(-1).reset_index()
    submit.columns = [int(col) if isinstance(col, int) else col for col in submit.columns.droplevel(0)]

    # Name the columns for the output format. The range must include topk
    # itself, otherwise the last rank column keeps its raw numeric label.
    item_columns = {'': 'user_id'}
    for i in range(1, topk + 1):
        item_columns[i] = 'item_' + str(i)
    submit = submit.rename(columns=item_columns)

    save_name = shixun_features_save_path + model_name + '_results.csv'
    submit.to_csv(save_name, sep='\t', index=False, header=True)
def norm_sim(sim_df, weight=0.0):
    """Min-max normalise ranking scores and shift them by ``weight``.

    Args:
        sim_df: Scores to normalise. Despite the name, the code treats this
            as a one-dimensional pandas object (its min and max are compared
            as scalars and ``apply`` runs element-wise).
        weight: Constant added to every normalised score.

    Returns:
        The result of ``sim_df.apply``: each score mapped to
        ``(score - min) / (max - min) + weight``, or to ``1.0 + weight``
        when all scores are equal.
    """
    lowest = sim_df.min()
    highest = sim_df.max()

    if highest == lowest:
        # Degenerate range: every score normalises to the same constant.
        def rescale(_score):
            return 1.0 + weight
    else:
        def rescale(score):
            return 1.0 * (score - lowest) / (highest - lowest) + weight

    return sim_df.apply(rescale)
def get_kfold_users(train_df, n=5):
    """Split the distinct users of ``train_df`` into ``n`` folds.

    The folds are built over users rather than over rows: the unique
    ``user_id`` values are dealt out round-robin, so fold ``k`` holds every
    n-th user starting at position ``k``.

    Args:
        train_df: DataFrame containing a ``user_id`` column.
        n: Number of folds.

    Returns:
        A list of ``n`` arrays of user ids.
    """
    distinct_users = train_df['user_id'].unique()
    folds = []
    for start in range(n):
        folds.append(distinct_users[start::n])
    return folds
def fill_is_trainee_hab(x):
    """Flag whether the item's difficulty is among the user's chosen difficulties.

    Args:
        x: Record exposing ``trainee`` and ``trainee_list`` attributes.

    Returns:
        1 when ``trainee_list`` is a set containing ``str(float(trainee))``,
        otherwise 0 (including when ``trainee_list`` is not a set).
    """
    chosen_levels = x.trainee_list
    if not isinstance(chosen_levels, set):
        return 0
    # Set members are compared in float-string form, e.g. '2.0'.
    level_key = str(float(x.trainee))
    return 1 if level_key in chosen_levels else 0
def make_shixun_tuple_func(df):
    """Convert each row of ``df`` into a tuple of its ranking features.

    Args:
        df: DataFrame containing the columns ``created_at_ts``, ``trainee``,
            ``visits``, ``myshixuns_count``, ``challenges_count``,
            ``averge_star`` and ``task_pass``.

    Returns:
        A list with one 7-tuple per row, fields in the order listed above.
    """
    feature_columns = (
        'created_at_ts',
        'trainee',
        'visits',
        'myshixuns_count',
        'challenges_count',
        'averge_star',
        'task_pass',
    )
    return [
        tuple(record[column] for column in feature_columns)
        for _, record in df.iterrows()
    ]
def get_rank_item_info_dict(item_info_df):
    """Build, or load from cache, the item-information dictionary.

    The feature columns of ``item_info_df`` are first cleaned in place:
    missing values are filled with zero and each column is cast to a fixed
    dtype. If ``shixuns_info_rank_dict.pkl`` already exists under
    ``shixun_save_path`` it is loaded and returned; otherwise the dictionary
    is built from ``item_info_df`` and written to that file.

    Args:
        item_info_df: DataFrame with a ``shixun_id`` column and the feature
            columns listed in ``column_specs`` below. Modified in place.

    Returns:
        A dict keyed by ``shixun_id``. When built here, each value is the
        list of feature tuples produced by ``make_shixun_tuple_func`` for
        that item's rows.
    """
    # (column name, value used for missing entries, target dtype)
    column_specs = (
        ('created_at_ts', 0.0, float),
        ('trainee', 0, int),
        ('visits', 0, int),
        ('myshixuns_count', 0, int),
        ('challenges_count', 0, int),
        ('averge_star', 0.0, float),
        ('task_pass', 0, int),
    )
    for column, fill_value, dtype in column_specs:
        item_info_df[column] = item_info_df[column].fillna(fill_value).astype(dtype)

    cache_path = shixun_save_path + 'shixuns_info_rank_dict.pkl'
    if os.path.exists(cache_path):
        # NOTE: the cache is used whenever the file exists, regardless of the
        # contents of item_info_df; delete the file to force a rebuild.
        # pickle must only be used on trusted, locally generated files.
        with open(cache_path, 'rb') as cache_file:
            return pickle.load(cache_file)

    item_info_tuples = item_info_df.groupby('shixun_id').progress_apply(make_shixun_tuple_func).reset_index()
    item_info_dict = dict(zip(item_info_tuples['shixun_id'], item_info_tuples[0]))
    # Context managers guarantee the file handle is closed (and the data
    # flushed) even if pickling fails.
    with open(cache_path, 'wb') as cache_file:
        pickle.dump(item_info_dict, cache_file)
    return item_info_dict
def get_item_bert_emb_dict():
    """Load, or build and cache, the items' normalised BERT embeddings.

    If the pickle at ``shixuns_bert_emb_dict`` exists it is loaded and
    returned. Otherwise the tab-separated file at ``shixuns_bert_em_path``
    is read, every column whose name contains ``bert_em`` is taken as an
    embedding dimension, each row is scaled to unit L2 norm, and the
    resulting dictionary is pickled to ``shixuns_bert_emb_dict``.

    Returns:
        A dict mapping ``shixun_id`` to its embedding row (a 1-D numpy
        array when built here).
    """
    # Reuse previously saved embeddings when available.
    # NOTE: pickle must only be used on trusted, locally generated files.
    if os.path.exists(shixuns_bert_emb_dict):
        with open(shixuns_bert_emb_dict, 'rb') as cache_file:
            return pickle.load(cache_file)

    # Build the embedding matrix from the raw file.
    item_emb_df = pd.read_csv(shixuns_bert_em_path, sep='\t', encoding='utf-8')
    item_emb_cols = [x for x in item_emb_df.columns if 'bert_em' in x]
    item_emb_np = np.ascontiguousarray(item_emb_df[item_emb_cols])

    # L2-normalise each row. An all-zero row has norm 0; dividing by 1
    # instead keeps it as zeros rather than turning it into NaN.
    row_norms = np.linalg.norm(item_emb_np, axis=1, keepdims=True)
    row_norms[row_norms == 0] = 1.0
    item_emb_np = item_emb_np / row_norms

    item_emb_dict = dict(zip(item_emb_df['shixun_id'], item_emb_np))
    # Context manager ensures the cache file is flushed and closed.
    with open(shixuns_bert_emb_dict, 'wb') as cache_file:
        pickle.dump(item_emb_dict, cache_file)
    return item_emb_dict