@ -9,7 +9,7 @@ import torch.nn as nn
from models . common import Conv , Bottleneck , SPP , DWConv , Focus , BottleneckCSP , Concat
from models . common import Conv , Bottleneck , SPP , DWConv , Focus , BottleneckCSP , Concat
from models . experimental import MixConv2d , CrossConv , C3
from models . experimental import MixConv2d , CrossConv , C3
from utils . general import check_anchor_order , make_divisible , check_file
from utils . general import check_anchor_order , make_divisible , check_file , set_logging
from utils . torch_utils import (
from utils . torch_utils import (
time_synchronized , fuse_conv_and_bn , model_info , scale_img , initialize_weights , select_device )
time_synchronized , fuse_conv_and_bn , model_info , scale_img , initialize_weights , select_device )
@ -156,7 +156,7 @@ class Model(nn.Module):
# print('%10.3g' % (m.w.detach().sigmoid() * 2)) # shortcut weights
# print('%10.3g' % (m.w.detach().sigmoid() * 2)) # shortcut weights
def fuse ( self ) : # fuse model Conv2d() + BatchNorm2d() layers
def fuse ( self ) : # fuse model Conv2d() + BatchNorm2d() layers
print ( ' Fusing layers... ' , end = ' ' )
print ( ' Fusing layers... ' )
for m in self . model . modules ( ) :
for m in self . model . modules ( ) :
if type ( m ) is Conv :
if type ( m ) is Conv :
m . _non_persistent_buffers_set = set ( ) # pytorch 1.6.0 compatability
m . _non_persistent_buffers_set = set ( ) # pytorch 1.6.0 compatability
@ -239,6 +239,7 @@ if __name__ == '__main__':
parser . add_argument ( ' --device ' , default = ' ' , help = ' cuda device, i.e. 0 or 0,1,2,3 or cpu ' )
parser . add_argument ( ' --device ' , default = ' ' , help = ' cuda device, i.e. 0 or 0,1,2,3 or cpu ' )
opt = parser . parse_args ( )
opt = parser . parse_args ( )
opt . cfg = check_file ( opt . cfg ) # check file
opt . cfg = check_file ( opt . cfg ) # check file
set_logging ( )
device = select_device ( opt . device )
device = select_device ( opt . device )
# Create model
# Create model