You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

65 lines
2.9 KiB

4 years ago
import sys
import os.path as osp
# Prepend the parent directory of this script to sys.path. This must run
# before the `DeepCrossing.*` imports below; presumably the `DeepCrossing`
# package lives in that parent directory -- TODO confirm the repo layout.
this_dir = osp.dirname(__file__)
lib_path = osp.join(this_dir, '..')
# Index 0 gives this path priority over any other entry on sys.path.
sys.path.insert(0, lib_path)
import torch
from DeepCrossing.trainer import Trainer
from DeepCrossing.network import DeepCrossing
from DeepCrossing.criteo_loader import getTestData, getTrainData
import torch.utils.data as Data
# Settings shared by the whole script: this same dict is handed to both the
# DeepCrossing network and the Trainer, and the script itself reads the file
# paths and the CUDA switches from it.
deepcrossing_config = {
    # Size of the dense vector produced when a sparse feature passes
    # through the Embedding layer.
    'embed_dim': 8,
    # Sparse features whose dimension is below min_dim go straight into the
    # stack layer without passing through an embedding layer.
    'min_dim': 256,
    'hidden_layers': [512, 256, 128, 64, 32],

    'num_epoch': 100,
    'batch_size': 32,
    'lr': 1e-3,
    'l2_regularization': 1e-4,

    # CUDA switches, read by the test stage of this script.
    'device_id': 0,
    'use_cuda': True,

    # Data and model locations, as relative paths.
    'train_file': '../Data/criteo/processed_data/train_set.csv',
    'fea_file': '../Data/criteo/processed_data/fea_col.npy',
    'validate_file': '../Data/criteo/processed_data/val_set.csv',
    'test_file': '../Data/criteo/processed_data/test_set.csv',
    'model_name': '../TrainedModels/DeepCrossing.model',
}
if __name__ == "__main__":
    # Script entry point: load the data, train and save a DeepCrossing
    # model, then reload the saved weights and print 0/1 predictions for
    # the test set.

    ####################################################################################
    # DeepCrossing model: load the data and build the network
    ####################################################################################
    training_data, training_label, dense_features_col, sparse_features_col = getTrainData(
        deepcrossing_config['train_file'], deepcrossing_config['fea_file'])
    train_dataset = Data.TensorDataset(
        torch.tensor(training_data).float(),
        torch.tensor(training_label).float())
    test_data = getTestData(deepcrossing_config['test_file'])
    # Build the test tensor once; it is moved to the GPU below only when
    # CUDA is enabled in the config.
    test_tensor = torch.tensor(test_data).float()
    deepCrossing = DeepCrossing(
        deepcrossing_config,
        dense_features_cols=dense_features_col,
        sparse_features_cols=sparse_features_col)

    ####################################################################################
    # Training stage
    ####################################################################################
    # Instantiate the model trainer
    trainer = Trainer(model=deepCrossing, config=deepcrossing_config)
    # Train
    trainer.train(train_dataset)
    # Save the model
    trainer.save()

    ####################################################################################
    # Test stage
    ####################################################################################
    deepCrossing.eval()
    if deepcrossing_config['use_cuda']:
        deepCrossing.loadModel(
            map_location=lambda storage, loc: storage.cuda(deepcrossing_config['device_id']))
        deepCrossing = deepCrossing.cuda()
        # Keep the input on the same kind of device as the model.
        test_tensor = test_tensor.cuda()
    else:
        # CPU path: the input must stay on the CPU too. Calling .cuda() on it
        # unconditionally would fail on a machine without CUDA.
        deepCrossing.loadModel(map_location=torch.device('cpu'))

    # Inference only: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        y_pred_probs = deepCrossing(test_tensor)
    # Threshold the model outputs at 0.5 to get hard 0/1 predictions.
    y_pred = torch.where(
        y_pred_probs > 0.5,
        torch.ones_like(y_pred_probs),
        torch.zeros_like(y_pred_probs))
    print("Test Data CTR Predict...\n ", y_pred.view(-1))