From 8fe299f1798e28e95a99f32350ad1837b4dd57b2 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Sun, 7 Jun 2020 14:10:30 -0700 Subject: [PATCH] model fuse --- utils/torch_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/utils/torch_utils.py b/utils/torch_utils.py index 00eb0cb..7eea4b4 100644 --- a/utils/torch_utils.py +++ b/utils/torch_utils.py @@ -90,7 +90,7 @@ def fuse_conv_and_bn(conv, bn): if conv.bias is not None: b_conv = conv.bias else: - b_conv = torch.zeros(conv.weight.size(0)) + b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps)) fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)