|
|
|
@ -12,6 +12,7 @@ import torchvision.models as models
|
|
|
|
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def init_seeds(seed=0):
|
|
|
|
|
torch.manual_seed(seed)
|
|
|
|
|
|
|
|
|
@ -144,7 +145,8 @@ def model_info(model, verbose=False):
|
|
|
|
|
except:
|
|
|
|
|
fs = ''
|
|
|
|
|
|
|
|
|
|
logger.info('Model Summary: %g layers, %g parameters, %g gradients%s' % (len(list(model.parameters())), n_p, n_g, fs))
|
|
|
|
|
logger.info(
|
|
|
|
|
'Model Summary: %g layers, %g parameters, %g gradients%s' % (len(list(model.parameters())), n_p, n_g, fs))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_classifier(name='resnet101', n=2):
|
|
|
|
|