|
|
|
@ -52,7 +52,8 @@ class Model(nn.Module):
|
|
|
|
|
self.md = yaml.load(f, Loader=yaml.FullLoader) # model dict
|
|
|
|
|
|
|
|
|
|
# Define model
|
|
|
|
|
if nc:
|
|
|
|
|
if nc and nc != self.md['nc']:
|
|
|
|
|
print('Overriding %s nc=%g with nc=%g' % (model_cfg, self.md['nc'], nc))
|
|
|
|
|
self.md['nc'] = nc # override yaml value
|
|
|
|
|
self.model, self.save = parse_model(self.md, ch=[ch]) # model, savelist, ch_out
|
|
|
|
|
# print([x.shape for x in self.forward(torch.zeros(1, ch, 64, 64))])
|
|
|
|
|