You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
110 lines
4.4 KiB
110 lines
4.4 KiB
5 years ago
|
import os
|
||
|
import argparse
|
||
|
|
||
|
import numpy as np
|
||
|
from scipy.spatial.distance import hamming, cdist
|
||
|
from net import AlexNetPlusLatent
|
||
|
|
||
|
from timeit import time
|
||
|
|
||
|
import torch
|
||
|
import torch.nn as nn
|
||
|
|
||
|
from torchvision import datasets, models, transforms
|
||
|
from torch.autograd import Variable
|
||
|
import torch.backends.cudnn as cudnn
|
||
|
import torch.optim.lr_scheduler
|
||
|
|
||
|
# Command-line options for the evaluation script.
parser = argparse.ArgumentParser(description='Deep Hashing evaluate mAP')
# --pretrained is the integer file name of the checkpoint read from ./model/.
# The value 0 is special-cased further down the script: it allows reuse of the
# tensors cached under ./result instead of recomputing them. The previous help
# text claimed "default = None", which did not match the real default of 92.
parser.add_argument('--pretrained', type=int, default=92, metavar='pretrained_model',
                    help='id of the checkpoint to load from ./model/<id>; '
                         '0 reuses cached results in ./result when present '
                         '(default: %(default)s)')
# --bits is passed to AlexNetPlusLatent as the size of the latent (hash) layer.
parser.add_argument('--bits', type=int, default=48, metavar='bts',
                    help='binary bits')
args = parser.parse_args()
|
||
|
|
||
|
def load_data():
    """Build unshuffled CIFAR-10 loaders for the train and test splits.

    Both splits are read from (and downloaded to, if missing) ``./data``.
    Every image goes through the same pipeline: resize with size 227,
    conversion to a tensor, then per-channel normalization.

    Returns:
        tuple: ``(trainloader, testloader)``, two ``DataLoader`` objects with
        ``batch_size=100``, ``shuffle=False`` and ``num_workers=2``. Shuffling
        is off so that outputs and labels stay in dataset order.
    """
    # transforms.Scale was deprecated in favour of transforms.Resize and is
    # absent from recent torchvision releases. Prefer Resize, and fall back to
    # Scale only on installs old enough to lack it.
    resize = getattr(transforms, 'Resize', None) or transforms.Scale

    def make_transform():
        # Normalization constants are presumably the CIFAR-10 per-channel
        # mean / std -- TODO confirm they match what the model was trained with.
        return transforms.Compose([
            resize(227),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])

    def make_loader(train):
        # The train and test splits were configured identically in the
        # original code; only the ``train`` flag differs.
        dataset = datasets.CIFAR10(root='./data', train=train, download=True,
                                   transform=make_transform())
        return torch.utils.data.DataLoader(dataset, batch_size=100,
                                           shuffle=False, num_workers=2)

    trainloader = make_loader(True)
    testloader = make_loader(False)
    return trainloader, testloader
|
||
|
|
||
|
def binary_output(dataloader):
    """Run the pretrained hashing network over ``dataloader`` and binarize it.

    A fresh ``AlexNetPlusLatent(args.bits)`` is created and its weights are
    loaded from ``./model/<args.pretrained>``. The first of the two values the
    network returns is collected for every batch and rounded element-wise.

    Args:
        dataloader: iterable yielding ``(inputs, targets)`` batches.

    Returns:
        tuple: ``(codes, labels)`` where ``codes`` is the rounded network
        output for all samples concatenated along dim 0 and ``labels`` the
        matching targets. Both live on the GPU when CUDA is available,
        otherwise on the CPU. Empty tensors are returned for an empty loader.
    """
    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')

    net = AlexNetPlusLatent(args.bits)
    # map_location lets a checkpoint saved from a GPU run load on a CPU-only
    # machine; the weights are moved to the target device right after.
    state_dict = torch.load('./model/%d' % args.pretrained, map_location='cpu')
    net.load_state_dict(state_dict)
    net.to(device)
    net.eval()

    # Gather per-batch results and concatenate once at the end; calling
    # torch.cat inside the loop re-copies everything accumulated so far.
    output_chunks = []
    label_chunks = []
    # torch.no_grad() replaces the removed Variable(..., volatile=True) idiom,
    # so no autograd graph is kept during inference.
    with torch.no_grad():
        for inputs, targets in dataloader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            outputs, _ = net(inputs)
            output_chunks.append(outputs)
            label_chunks.append(targets)

    if not output_chunks:
        # Mirrors the original behaviour of returning empty accumulators.
        empty_codes = torch.empty(0, dtype=torch.float32, device=device)
        empty_labels = torch.empty(0, dtype=torch.long, device=device)
        return empty_codes, empty_labels

    full_batch_output = torch.cat(output_chunks, 0)
    full_batch_label = torch.cat(label_chunks, 0)
    return torch.round(full_batch_output), full_batch_label
|
||
|
|
||
|
def precision(trn_binary, trn_label, tst_binary, tst_label):
    """Compute, print and return the mean average precision of Hamming ranking.

    Each row of ``tst_binary`` is used as a query. All rows of ``trn_binary``
    are ranked by the number of positions in which they differ from the query
    (unnormalized Hamming distance). A ranked item counts as relevant when its
    label equals the query label. The average precision of every query is the
    mean of precision@k over the ranks k that hold a relevant item.

    Args:
        trn_binary: torch tensor of database codes, one row per item.
        trn_label: torch tensor of database labels, one entry per row.
        tst_binary: torch tensor of query codes, one row per query.
        tst_label: torch tensor of query labels, one entry per row.

    Returns:
        float: the mean of the per-query average precisions. A query with no
        relevant database item contributes 0. ``nan`` when there are no
        queries. The value is also printed, followed by the elapsed time.
    """
    trn_binary = np.asarray(trn_binary.cpu().numpy(), np.int32)
    trn_label = trn_label.cpu().numpy()
    tst_binary = np.asarray(tst_binary.cpu().numpy(), np.int32)
    tst_label = tst_label.cpu().numpy()

    query_times = tst_binary.shape[0]
    # Use the argument, not the module-level ``train_binary`` global that the
    # original code read here by accident.
    trainset_len = trn_binary.shape[0]
    AP = np.zeros(query_times)
    # Float ranks keep precision@k a true division regardless of int dtypes.
    Ns = np.arange(1, trainset_len + 1, dtype=np.float64)

    total_time_start = time.time()
    for i in range(query_times):
        print('Query ', i+1)
        query_label = tst_label[i]
        query_binary = tst_binary[i, :]
        # Count of differing bits; dividing by the code length would not
        # change the ranking, so it is skipped.
        query_result = np.count_nonzero(query_binary != trn_binary, axis=1)
        sort_indices = np.argsort(query_result)
        buffer_yes = np.equal(query_label, trn_label[sort_indices]).astype(int)
        relevant_total = buffer_yes.sum()
        if relevant_total == 0:
            # No database item shares the query label: leave AP[i] at 0
            # instead of producing 0/0 = nan, which would poison the mean.
            continue
        P = np.cumsum(buffer_yes) / Ns
        AP[i] = np.sum(P * buffer_yes) / relevant_total

    # Named mean_ap rather than ``map`` to avoid shadowing the builtin.
    mean_ap = float(np.mean(AP)) if query_times else float('nan')
    print(mean_ap)
    print('total query time = ', time.time() - total_time_start)
    return mean_ap
|
||
|
|
||
|
|
||
|
|
||
|
# Driver: obtain binarized outputs and labels for both splits, then evaluate.
# The tensors cached under ./result are reused only when all four files exist
# AND the script was started with --pretrained 0; otherwise they are
# recomputed from the checkpoint and (re)written to ./result.
cache_ready = all(
    os.path.exists(cached_file)
    for cached_file in ('./result/train_binary', './result/train_label',
                        './result/test_binary', './result/test_label')
)

if cache_ready and args.pretrained == 0:
    train_binary = torch.load('./result/train_binary')
    train_label = torch.load('./result/train_label')
    test_binary = torch.load('./result/test_binary')
    test_label = torch.load('./result/test_label')
else:
    trainloader, testloader = load_data()
    train_binary, train_label = binary_output(trainloader)
    test_binary, test_label = binary_output(testloader)

    # Create the cache directory on first use.
    if not os.path.isdir('result'):
        os.mkdir('result')

    torch.save(train_binary, './result/train_binary')
    torch.save(train_label, './result/train_label')
    torch.save(test_binary, './result/test_binary')
    torch.save(test_label, './result/test_label')

# Database = training split, queries = test split.
precision(train_binary, train_label, test_binary, test_label)
|