|
|
import sys
|
|
|
import os.path as osp
|
|
|
this_dir = osp.dirname(__file__)
|
|
|
lib_path = osp.join(this_dir, '..')
|
|
|
sys.path.insert(0, lib_path)
|
|
|
import torch
|
|
|
from AutoRec.trainer import Trainer
|
|
|
from AutoRec.network import AutoRec
|
|
|
from AutoRec.dataloader import dataProcess
|
|
|
|
|
|
autorec_config = \
|
|
|
{
|
|
|
'train_ratio': 0.9,
|
|
|
'num_epoch': 100,
|
|
|
'batch_size': 100,
|
|
|
'optimizer': 'adam',
|
|
|
'adam_lr': 1e-3,
|
|
|
'l2_regularization':1e-4,
|
|
|
'num_users': 6040,
|
|
|
'num_items': 3952,
|
|
|
'hidden_units': 500,
|
|
|
'lambda': 1,
|
|
|
'device_id': 0,
|
|
|
'use_cuda': True,
|
|
|
'data_file': '../Data/ml-1m/ratings.dat',
|
|
|
'model_name': '../TrainedModels/AutoRec.model'
|
|
|
}
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
####################################################################################
|
|
|
# AutoRec 自编码器协同过滤算法
|
|
|
####################################################################################
|
|
|
train_r, train_mask_r, test_r, test_mask_r, \
|
|
|
user_train_set, item_train_set, user_test_set, item_test_set = \
|
|
|
dataProcess(autorec_config['data_file'], autorec_config['num_users'], autorec_config['num_items'], autorec_config['train_ratio'])
|
|
|
# 实例化AutoRec对象
|
|
|
autorec = AutoRec(config=autorec_config)
|
|
|
|
|
|
####################################################################################
|
|
|
# 模型训练阶段
|
|
|
###################################################################################
|
|
|
# 实例化模型训练器
|
|
|
trainer = Trainer(model=autorec, config=autorec_config)
|
|
|
# 开始训练
|
|
|
trainer.train(train_r, train_mask_r)
|
|
|
# 保存模型
|
|
|
trainer.save()
|
|
|
|
|
|
###################################################################################
|
|
|
# 模型测试阶段
|
|
|
###################################################################################
|
|
|
# 实例化AutoRec对象
|
|
|
autorec = AutoRec(config=autorec_config)
|
|
|
autorec.loadModel(map_location=torch.device('cpu'))
|
|
|
|
|
|
# 进行性能评估
|
|
|
autorec.evaluate(test_r, test_mask_r, user_test_set=user_test_set, user_train_set=user_train_set, \
|
|
|
item_test_set=item_test_set, item_train_set=item_train_set)
|
|
|
|
|
|
# 从测试集中抽取1个用户,推荐5个商品
|
|
|
print(autorec.recommend_user(test_r[0], 5))
|
|
|
|
|
|
|
|
|
|