|
|
|
@ -12,6 +12,7 @@ import torchvision.models as models
|
|
|
|
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def init_seeds(seed=0):
|
|
|
|
|
torch.manual_seed(seed)
|
|
|
|
|
|
|
|
|
@ -43,7 +44,7 @@ def select_device(device='', batch_size=None):
|
|
|
|
|
if i == 1:
|
|
|
|
|
s = ' ' * len(s)
|
|
|
|
|
logger.info("%sdevice%g _CudaDeviceProperties(name='%s', total_memory=%dMB)" %
|
|
|
|
|
(s, i, x[i].name, x[i].total_memory / c))
|
|
|
|
|
(s, i, x[i].name, x[i].total_memory / c))
|
|
|
|
|
else:
|
|
|
|
|
logger.info('Using CPU')
|
|
|
|
|
|
|
|
|
@ -144,7 +145,8 @@ def model_info(model, verbose=False):
|
|
|
|
|
except:
|
|
|
|
|
fs = ''
|
|
|
|
|
|
|
|
|
|
logger.info('Model Summary: %g layers, %g parameters, %g gradients%s' % (len(list(model.parameters())), n_p, n_g, fs))
|
|
|
|
|
logger.info(
|
|
|
|
|
'Model Summary: %g layers, %g parameters, %g gradients%s' % (len(list(model.parameters())), n_p, n_g, fs))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_classifier(name='resnet101', n=2):
|
|
|
|
|