commit 519a23ca50
@@ -0,0 +1,29 @@
# pytorch_deephash

## Introduction

This is a PyTorch implementation of [Deep Learning of Binary Hash Codes for Fast Image Retrieval](https://github.com/kevinlin311tw/caffe-cvprw15); it achieves more than 93% mAP on the CIFAR-10 dataset.

## Environment

> PyTorch 0.4.0

> torchvision 0.2.1

## Training

```shell
python train.py
```

Trained models are saved to the `model` folder by default; each model file is named after its test accuracy.
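
For example, after training the folder might contain entries like the following (these file names are hypothetical; yours depend on the accuracies you reach):

```shell
$ ls model
90.3  91.75  92.06
```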

## Evaluation

```shell
python mAP.py --pretrained {your saved model name in the model folder}
```

## Tips

There are several other command-line arguments; list them by passing `-h`, or see the argparse setup in the code.
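
For instance, `mAP.py` accepts the following flags (taken from its argparse definitions; `92.06` is a hypothetical model name):

```shell
python mAP.py --pretrained 92.06 --bits 48 --path model
```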

@@ -0,0 +1,135 @@
import os
import argparse

import numpy as np
from net import AlexNetPlusLatent

import torch
import torch.nn as nn

from torchvision import datasets, models, transforms
from torch.autograd import Variable
import torch.backends.cudnn as cudnn
import torch.optim.lr_scheduler

parser = argparse.ArgumentParser(description='Deep Hashing evaluate mAP')
parser.add_argument('--pretrained', type=str, default='92', metavar='pretrained_model',
                    help='name of the saved model to load (default: 92); pass 0 to reuse cached binary codes')
parser.add_argument('--bits', type=int, default=48, metavar='bts',
                    help='binary bits')
parser.add_argument('--path', type=str, default='model', metavar='P',
                    help='path directory')
args = parser.parse_args()
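
# Note: computed binary codes are cached under ./result; passing
# '--pretrained 0' reuses that cache instead of reloading a model
# (see the bottom of this file).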

def load_data():
    transform_train = transforms.Compose(
        [transforms.Resize(227),
         transforms.ToTensor(),
         transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])
    transform_test = transforms.Compose(
        [transforms.Resize(227),
         transforms.ToTensor(),
         transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])
    trainset = datasets.CIFAR10(root='./data', train=True, download=True,
                                transform=transform_train)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=100,
                                              shuffle=False, num_workers=2)

    testset = datasets.CIFAR10(root='./data', train=False, download=True,
                               transform=transform_test)
    testloader = torch.utils.data.DataLoader(testset, batch_size=100,
                                             shuffle=False, num_workers=2)
    return trainloader, testloader
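
# binary_output: run the trained network over a loader and round the
# sigmoid activations of the latent layer to {0, 1} binary codes;
# the ground-truth labels are collected alongside for evaluation.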
def binary_output(dataloader):
    net = AlexNetPlusLatent(args.bits)
    net.load_state_dict(torch.load('./{}/{}'.format(args.path, args.pretrained)))
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        net.cuda()
    # accumulate outputs on the same device the model runs on
    full_batch_output = torch.cuda.FloatTensor() if use_cuda else torch.FloatTensor()
    full_batch_label = torch.cuda.LongTensor() if use_cuda else torch.LongTensor()
    net.eval()
    for batch_idx, (inputs, targets) in enumerate(dataloader):
        if use_cuda:
            inputs, targets = inputs.cuda(), targets.cuda()
        inputs, targets = Variable(inputs), Variable(targets)
        outputs, _ = net(inputs)
        full_batch_output = torch.cat((full_batch_output, outputs.data), 0)
        full_batch_label = torch.cat((full_batch_label, targets.data), 0)
    return torch.round(full_batch_output), full_batch_label
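
# Evaluation: sample 100 test images per class as queries, rank the whole
# training set by Hamming distance, then report mAP, precision@k, and
# precision within Hamming radius 2.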
def precision(trn_binary, trn_label, tst_binary, tst_label):
    trn_binary = trn_binary.cpu().numpy()
    trn_binary = np.asarray(trn_binary, np.int32)
    trn_label = trn_label.cpu().numpy()
    tst_binary = tst_binary.cpu().numpy()
    tst_binary = np.asarray(tst_binary, np.int32)
    tst_label = tst_label.cpu().numpy()
    classes = np.max(tst_label) + 1

    # draw 100 query images per class, with a fixed per-class seed
    for i in range(classes):
        sample_indices = np.random.RandomState(seed=i).permutation(np.where(tst_label == i)[0])[:100]
        if i == 0:
            tst_sample_binary = tst_binary[sample_indices]
            tst_sample_label = np.array([i]).repeat(100)
        else:
            tst_sample_binary = np.concatenate([tst_sample_binary, tst_binary[sample_indices]])
            tst_sample_label = np.concatenate([tst_sample_label, np.array([i]).repeat(100)])

    query_times = tst_sample_binary.shape[0]
    trainset_len = trn_binary.shape[0]
    AP = np.zeros(query_times)
    precision_radius = np.zeros(query_times)
    Ns = np.arange(1, trainset_len + 1)
    sum_tp = np.zeros(trainset_len)
    for i in range(query_times):
        print('Query ', i + 1)
        query_label = tst_sample_label[i]
        query_binary = tst_sample_binary[i, :]
        # Hamming distance; dividing by the code length is unnecessary for ranking
        query_result = np.count_nonzero(query_binary != trn_binary, axis=1)
        sort_indices = np.argsort(query_result)
        buffer_yes = np.equal(query_label, trn_label[sort_indices]).astype(int)
        P = np.cumsum(buffer_yes) / Ns
        # precision over all retrieved items within Hamming radius 2
        precision_radius[i] = P[np.where(np.sort(query_result) > 2)[0][0] - 1]
        AP[i] = np.sum(P * buffer_yes) / sum(buffer_yes)
        sum_tp = sum_tp + np.cumsum(buffer_yes)
    precision_at_k = sum_tp / Ns / query_times
    index = [100, 200, 400, 600, 800, 1000]
    index = [i - 1 for i in index]
    print('precision at k:', precision_at_k[index])
    np.save('precision_at_k', precision_at_k)
    print('precision within Hamming radius 2:', np.mean(precision_radius))
    mAP = np.mean(AP)
    print('mAP:', mAP)


# Reuse cached binary codes when '--pretrained 0' is given; otherwise
# compute them from the saved model and cache them under ./result.
if os.path.exists('./result/train_binary') and os.path.exists('./result/train_label') and \
   os.path.exists('./result/test_binary') and os.path.exists('./result/test_label') and str(args.pretrained) == '0':
    train_binary = torch.load('./result/train_binary')
    train_label = torch.load('./result/train_label')
    test_binary = torch.load('./result/test_binary')
    test_label = torch.load('./result/test_label')

else:
    trainloader, testloader = load_data()
    train_binary, train_label = binary_output(trainloader)
    test_binary, test_label = binary_output(testloader)
    if not os.path.isdir('result'):
        os.mkdir('result')
    torch.save(train_binary, './result/train_binary')
    torch.save(train_label, './result/train_label')
    torch.save(test_binary, './result/test_binary')
    torch.save(test_label, './result/test_label')

precision(train_binary, train_label, test_binary, test_label)

@@ -0,0 +1,109 @@
import os
import argparse

import numpy as np
from net import AlexNetPlusLatent

import time

import torch
import torch.nn as nn

from torchvision import datasets, models, transforms
from torch.autograd import Variable
import torch.backends.cudnn as cudnn
import torch.optim.lr_scheduler

parser = argparse.ArgumentParser(description='Deep Hashing evaluate mAP')
parser.add_argument('--pretrained', type=int, default=92, metavar='pretrained_model',
                    help='number of the saved model to load (default: 92); pass 0 to reuse cached binary codes')
parser.add_argument('--bits', type=int, default=48, metavar='bts',
                    help='binary bits')
args = parser.parse_args()

def load_data():
    transform_train = transforms.Compose(
        [transforms.Resize(227),
         transforms.ToTensor(),
         transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])
    transform_test = transforms.Compose(
        [transforms.Resize(227),
         transforms.ToTensor(),
         transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])
    trainset = datasets.CIFAR10(root='./data', train=True, download=True,
                                transform=transform_train)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=100,
                                              shuffle=False, num_workers=2)

    testset = datasets.CIFAR10(root='./data', train=False, download=True,
                               transform=transform_test)
    testloader = torch.utils.data.DataLoader(testset, batch_size=100,
                                             shuffle=False, num_workers=2)
    return trainloader, testloader
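
# Run the network over a loader and round the latent sigmoid activations
# to {0, 1} binary codes; labels are collected alongside for evaluation.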
def binary_output(dataloader):
    net = AlexNetPlusLatent(args.bits)
    net.load_state_dict(torch.load('./model/%d' % args.pretrained))
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        net.cuda()
    # accumulate outputs on the same device the model runs on
    full_batch_output = torch.cuda.FloatTensor() if use_cuda else torch.FloatTensor()
    full_batch_label = torch.cuda.LongTensor() if use_cuda else torch.LongTensor()
    net.eval()
    for batch_idx, (inputs, targets) in enumerate(dataloader):
        if use_cuda:
            inputs, targets = inputs.cuda(), targets.cuda()
        inputs, targets = Variable(inputs, volatile=True), Variable(targets)
        outputs, _ = net(inputs)
        full_batch_output = torch.cat((full_batch_output, outputs.data), 0)
        full_batch_label = torch.cat((full_batch_label, targets.data), 0)
    return torch.round(full_batch_output), full_batch_label
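
# Evaluation: every test image is used as a query against the whole
# training set, ranked by Hamming distance; reports mAP over all queries.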
def precision(trn_binary, trn_label, tst_binary, tst_label):
    trn_binary = trn_binary.cpu().numpy()
    trn_binary = np.asarray(trn_binary, np.int32)
    trn_label = trn_label.cpu().numpy()
    tst_binary = tst_binary.cpu().numpy()
    tst_binary = np.asarray(tst_binary, np.int32)
    tst_label = tst_label.cpu().numpy()
    query_times = tst_binary.shape[0]
    trainset_len = trn_binary.shape[0]
    AP = np.zeros(query_times)
    Ns = np.arange(1, trainset_len + 1)
    total_time_start = time.time()
    for i in range(query_times):
        print('Query ', i + 1)
        query_label = tst_label[i]
        query_binary = tst_binary[i, :]
        # Hamming distance; dividing by the code length is unnecessary for ranking
        query_result = np.count_nonzero(query_binary != trn_binary, axis=1)
        sort_indices = np.argsort(query_result)
        buffer_yes = np.equal(query_label, trn_label[sort_indices]).astype(int)
        P = np.cumsum(buffer_yes) / Ns
        AP[i] = np.sum(P * buffer_yes) / sum(buffer_yes)
    mAP = np.mean(AP)
    print('mAP:', mAP)
    print('total query time = ', time.time() - total_time_start)


# Reuse cached binary codes when '--pretrained 0' is given; otherwise
# compute them from the saved model and cache them under ./result.
if os.path.exists('./result/train_binary') and os.path.exists('./result/train_label') and \
   os.path.exists('./result/test_binary') and os.path.exists('./result/test_label') and args.pretrained == 0:
    train_binary = torch.load('./result/train_binary')
    train_label = torch.load('./result/train_label')
    test_binary = torch.load('./result/test_binary')
    test_label = torch.load('./result/test_label')

else:
    trainloader, testloader = load_data()
    train_binary, train_label = binary_output(trainloader)
    test_binary, test_label = binary_output(testloader)
    if not os.path.isdir('result'):
        os.mkdir('result')
    torch.save(train_binary, './result/train_binary')
    torch.save(train_label, './result/train_label')
    torch.save(test_binary, './result/test_binary')
    torch.save(test_label, './result/test_label')

precision(train_binary, train_label, test_binary, test_label)

@@ -0,0 +1,26 @@
import torch.nn as nn
from torchvision import models

# Pre-trained AlexNet from torchvision, used as the backbone
alexnet_model = models.alexnet(pretrained=True)


# AlexNet with an extra latent layer between fc7 and the classifier;
# the sigmoid activations of that layer are later rounded into binary
# hash codes (see binary_output in the evaluation scripts).
class AlexNetPlusLatent(nn.Module):
    def __init__(self, bits):
        super(AlexNetPlusLatent, self).__init__()
        self.bits = bits
        # convolutional feature extractor, copied from AlexNet
        self.features = nn.Sequential(*list(alexnet_model.features.children()))
        # fc6 and fc7: the original classifier minus its last layer
        self.remain = nn.Sequential(*list(alexnet_model.classifier.children())[:-1])
        # latent layer: 4096 -> bits, squashed to (0, 1) by a sigmoid
        self.Linear1 = nn.Linear(4096, self.bits)
        self.sigmoid = nn.Sigmoid()
        # classification head on top of the latent layer (10 CIFAR-10 classes)
        self.Linear2 = nn.Linear(self.bits, 10)

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), 256 * 6 * 6)
        x = self.remain(x)
        x = self.Linear1(x)
        features = self.sigmoid(x)
        result = self.Linear2(features)
        return features, result
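
# Minimal usage sketch (assumes 227x227 inputs, as in the evaluation scripts):
#   net = AlexNetPlusLatent(bits=48)
#   features, logits = net(images)   # images: (N, 3, 227, 227)
#   codes = torch.round(features)    # binary hash codes in {0, 1}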