You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

65 lines
2.9 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import sys
import os.path as osp
# Prepend the parent of this script's directory to the module search path;
# the project imports that follow this statement rely on it being first.
this_dir = osp.dirname(__file__)
lib_path = osp.join(this_dir, osp.pardir)
sys.path.insert(0, lib_path)
import torch
from DeepCrossing.trainer import Trainer
from DeepCrossing.network import DeepCrossing
from DeepCrossing.criteo_loader import getTestData, getTrainData
import torch.utils.data as Data
# Hyper-parameters, runtime options and file locations for the DeepCrossing
# experiment. The same dict is handed to both the network and the trainer.
deepcrossing_config = {
    # --- network architecture ---
    # Size of the dense vector a sparse feature becomes after the Embedding layer.
    'embed_dim': 8,
    # Sparse features whose dimension is below min_dim go straight into the
    # stacking layer and skip the embedding layer.
    'min_dim': 256,
    'hidden_layers': [512, 256, 128, 64, 32],

    # --- optimisation ---
    'num_epoch': 100,
    'batch_size': 32,
    'lr': 1e-3,
    'l2_regularization': 1e-4,

    # --- hardware ---
    'device_id': 0,
    'use_cuda': True,

    # --- data and checkpoint paths (relative to the working directory) ---
    'train_file': '../Data/criteo/processed_data/train_set.csv',
    'fea_file': '../Data/criteo/processed_data/fea_col.npy',
    'validate_file': '../Data/criteo/processed_data/val_set.csv',
    'test_file': '../Data/criteo/processed_data/test_set.csv',
    'model_name': '../TrainedModels/DeepCrossing.model',
}
if __name__ == "__main__":
    ####################################################################################
    # DeepCrossing model
    ####################################################################################
    # Load the training split together with the dense / sparse feature column
    # descriptions that the network constructor requires.
    training_data, training_label, dense_features_col, sparse_features_col = getTrainData(
        deepcrossing_config['train_file'], deepcrossing_config['fea_file'])
    train_dataset = Data.TensorDataset(torch.tensor(training_data).float(),
                                       torch.tensor(training_label).float())

    # The test split is only used for one forward pass after training, so it is
    # converted to a tensor once and no Dataset wrapper is built for it.
    test_data = getTestData(deepcrossing_config['test_file'])
    test_tensor = torch.tensor(test_data).float()

    deepCrossing = DeepCrossing(deepcrossing_config,
                                dense_features_cols=dense_features_col,
                                sparse_features_cols=sparse_features_col)

    ####################################################################################
    # Model training phase
    ####################################################################################
    # Instantiate the model trainer
    trainer = Trainer(model=deepCrossing, config=deepcrossing_config)
    # Train
    trainer.train(train_dataset)
    # Save the model
    trainer.save()

    ####################################################################################
    # Model testing phase
    ####################################################################################
    deepCrossing.eval()
    if deepcrossing_config['use_cuda']:
        # Honour the configured GPU index for the checkpoint, the model and the input.
        device = torch.device('cuda', deepcrossing_config['device_id'])
        deepCrossing.loadModel(
            map_location=lambda storage, loc: storage.cuda(deepcrossing_config['device_id']))
    else:
        device = torch.device('cpu')
        deepCrossing.loadModel(map_location=device)
    deepCrossing = deepCrossing.to(device)

    # Move the input to the same device as the model (previously it was always sent
    # to CUDA, which broke the CPU path) and skip autograd bookkeeping for inference.
    with torch.no_grad():
        y_pred_probs = deepCrossing(test_tensor.to(device))

    # Threshold the predicted probabilities at 0.5 to obtain hard 0/1 labels.
    y_pred = torch.where(y_pred_probs > 0.5,
                         torch.ones_like(y_pred_probs),
                         torch.zeros_like(y_pred_probs))
    print("Test Data CTR Predict...\n ", y_pred.view(-1))