You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
30 lines
835 B
30 lines
835 B
# -*- coding: utf-8 -*-
|
|
# File : unittest.py
|
|
# Author : Jiayuan Mao
|
|
# Email : maojiayuan@gmail.com
|
|
# Date : 27/01/2018
|
|
#
|
|
# This file is part of Synchronized-BatchNorm-PyTorch.
|
|
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
|
|
# Distributed under MIT License.
|
|
|
|
import unittest
|
|
|
|
import numpy as np
|
|
from torch.autograd import Variable
|
|
|
|
|
|
def as_numpy(v):
    """Return the contents of a torch tensor as a NumPy array.

    If ``v`` is an ``autograd.Variable`` it is first unwrapped to its
    ``.data`` tensor; the result is then copied to the CPU with ``.cpu()``
    and converted with ``.numpy()``.
    """
    tensor = v.data if isinstance(v, Variable) else v
    return tensor.cpu().numpy()
|
|
|
|
|
|
class TorchTestCase(unittest.TestCase):
    """``unittest.TestCase`` extended with a tolerance-based tensor check."""

    def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3):
        """Assert that ``a`` and ``b`` are element-wise close.

        Both inputs are converted with ``as_numpy`` and compared with
        ``np.allclose(npa, npb, atol=atol, rtol=rtol)``, i.e. every element
        must satisfy ``|a - b| <= atol + rtol * |b|``.

        Args:
            a: tensor (or ``Variable``) under test.
            b: reference tensor (or ``Variable``).
            atol: absolute tolerance.
            rtol: relative tolerance, scaled by ``|b|``.

        Raises:
            AssertionError (``self.failureException``): if any element is
                outside tolerance. The message shows both tensors plus the
                maximum absolute and relative differences.
        """
        npa, npb = as_numpy(a), as_numpy(b)
        # rtol must be forwarded explicitly; otherwise np.allclose silently
        # falls back to its own default (1e-5) and ignores the caller's value.
        if np.allclose(npa, npb, atol=atol, rtol=rtol):
            return

        # Diagnostics are computed only on failure: reducing with .max() over
        # a zero-size array raises, and empty inputs always pass above.
        diff = np.abs(npa - npb)
        adiff = diff.max()
        # Relative error is measured against |b| (as np.allclose does); the
        # 1e-5 floor avoids division by zero and, unlike clamping the signed
        # value, stays correct for negative entries.
        rdiff = (diff / np.fmax(np.abs(npb), 1e-5)).max()
        message = 'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format(
            a, b, adiff, rdiff)
        self.fail(message)
|