You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

65 lines
2.2 KiB

3 years ago
import sys
import os.path as osp
this_dir = osp.dirname(__file__)
lib_path = osp.join(this_dir, '..')
sys.path.insert(0, lib_path)
import torch
from AutoRec.trainer import Trainer
from AutoRec.network import AutoRec
from AutoRec.dataloader import dataProcess
autorec_config = \
{
'train_ratio': 0.9,
'num_epoch': 100,
'batch_size': 100,
'optimizer': 'adam',
'adam_lr': 1e-3,
'l2_regularization':1e-4,
'num_users': 6040,
'num_items': 3952,
'hidden_units': 500,
'lambda': 1,
'device_id': 0,
'use_cuda': True,
'data_file': '../Data/ml-1m/ratings.dat',
'model_name': '../TrainedModels/AutoRec.model'
}
if __name__ == "__main__":
####################################################################################
# AutoRec 自编码器协同过滤算法
####################################################################################
train_r, train_mask_r, test_r, test_mask_r, \
user_train_set, item_train_set, user_test_set, item_test_set = \
dataProcess(autorec_config['data_file'], autorec_config['num_users'], autorec_config['num_items'], autorec_config['train_ratio'])
# 实例化AutoRec对象
autorec = AutoRec(config=autorec_config)
####################################################################################
# 模型训练阶段
###################################################################################
# 实例化模型训练器
trainer = Trainer(model=autorec, config=autorec_config)
# 开始训练
trainer.train(train_r, train_mask_r)
# 保存模型
trainer.save()
###################################################################################
# 模型测试阶段
###################################################################################
# 实例化AutoRec对象
autorec = AutoRec(config=autorec_config)
autorec.loadModel(map_location=torch.device('cpu'))
# 进行性能评估
autorec.evaluate(test_r, test_mask_r, user_test_set=user_test_set, user_train_set=user_train_set, \
item_test_set=item_test_set, item_train_set=item_train_set)
# 从测试集中抽取1个用户推荐5个商品
print(autorec.recommend_user(test_r[0], 5))