"""Inference script: loads a trained UResnet34 checkpoint, predicts segmentation
masks for the given input images, and optionally saves the results and shows them
next to the ground-truth masks."""

import argparse
import logging
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms

from unet import UNet, SEUNet, UResnet34, UResnet50, UResnet101, UResnet152
from utils.data_loading import BasicDataset
from utils.utils import plot_img_and_mask

# os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'


def predict_img(net, full_img, device, scale_factor=1, out_threshold=0.5):
    """Run the network on a single PIL image and return the predicted mask as a numpy array."""
    net.eval()
    img = torch.from_numpy(BasicDataset.preprocess(None, full_img, scale_factor, is_mask=False))
    img = img.unsqueeze(0)
    img = img.to(device=device, dtype=torch.float32)

    with torch.no_grad():
        output = net(img).cpu()
        # Resize the logits back to the original image resolution (PIL size is (W, H)).
        output = F.interpolate(output, (full_img.size[1], full_img.size[0]), mode='bilinear')
        if net.n_classes > 1:
            output = F.softmax(output, dim=1)
            # Optional channel filtering, kept for reference:
            # output = output[:, 1:, :, :]  # remove the first channel (background)
            # output = output[:, :1, :, :]  # remove the second channel (vessels)
            logging.debug(f'Network output shape: {output.shape}')
            mask = output.argmax(dim=1)
        else:
            mask = torch.sigmoid(output) > out_threshold

    return mask[0].long().squeeze().numpy()


def get_args():
    parser = argparse.ArgumentParser(description='Predict masks from input images')
    parser.add_argument('--model', '-m', default='checkpoints/checkpoint_epoch24.pth', metavar='FILE',
                        help='Specify the file in which the model is stored')
    parser.add_argument('--input', '-i', metavar='INPUT', nargs='+', required=True,
                        help='Filenames of input images')
    parser.add_argument('--output', '-o', metavar='OUTPUT', nargs='+',
                        help='Filenames of output images')
    parser.add_argument('--viz', '-v', action='store_true',
                        help='Visualize the images as they are processed')
    parser.add_argument('--no-save', '-n', action='store_true',
                        help='Do not save the output masks')
    parser.add_argument('--mask-threshold', '-t', type=float, default=0.5,
                        help='Minimum probability value to consider a mask pixel white')
    parser.add_argument('--scale', '-s', type=float, default=0.5,
                        help='Scale factor for the input images')
    parser.add_argument('--bilinear', action='store_true', default=False,
                        help='Use bilinear upsampling')
    parser.add_argument('--classes', '-c', type=int, default=2,
                        help='Number of classes')
    return parser.parse_args()


def get_output_filenames(args):
    def _generate_name(fn):
        return f'{os.path.splitext(fn)[0]}_OUT.png'

    return args.output or list(map(_generate_name, args.input))


def mask_to_image(mask: np.ndarray, mask_values):
    """Map predicted class indices back to the pixel values stored in the checkpoint."""
    if isinstance(mask_values[0], list):
        out = np.zeros((mask.shape[-2], mask.shape[-1], len(mask_values[0])), dtype=np.uint8)
    elif mask_values == [0, 1]:
        out = np.zeros((mask.shape[-2], mask.shape[-1]), dtype=bool)
    else:
        out = np.zeros((mask.shape[-2], mask.shape[-1]), dtype=np.uint8)

    if mask.ndim == 3:
        mask = np.argmax(mask, axis=0)

    for i, v in enumerate(mask_values):
        out[mask == i] = v

    return Image.fromarray(out)


def MainSolve(model_path, input_files, output_files=None, visualize=False, no_save=False,
              mask_threshold=0.5, scale=0.5, bilinear=False, num_classes=2):
    logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')

    in_files = input_files
    out_files = output_files if output_files else [f'{os.path.splitext(fn)[0]}_OUT.png' for fn in input_files]

    net = UResnet34(n_channels=3, n_classes=num_classes, bilinear=bilinear)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    logging.info(f'Loading model {model_path}')
    logging.info(f'Using device {device}')

    net.to(device=device)
    state_dict = torch.load(model_path, map_location=device)
    mask_values = state_dict.pop('mask_values', [0, 1])
    net.load_state_dict(state_dict, strict=False)
    logging.info('Model loaded!')

    for i, filename in enumerate(in_files):
        logging.info(f'Predicting image {filename} ...')
        img = Image.open(filename)
        # Ground-truth masks are located by replacing 'image_1' with 'mask' in the input path.
        true_mask = Image.open(filename.replace('image_1', 'mask'))

        mask = predict_img(net=net,
                           full_img=img,
                           scale_factor=scale,
                           out_threshold=mask_threshold,
                           device=device)

        if not no_save:
            out_filename = out_files[i]
            result = mask_to_image(mask, mask_values)
            result.putpalette([
                0, 0, 0,        # black background
                255, 255, 255,  # class 1
                0, 0, 255,      # class 2
                0, 255, 0,      # class 3
            ])
            result.save(out_filename)
            logging.info(f'Mask saved to {out_filename}')

            plt.figure(figsize=(10, 10))

            plt.subplot(1, 3, 1)
            plt.imshow(img)
            plt.axis('off')
            plt.title('Original Image')

            plt.subplot(1, 3, 2)
            plt.imshow(result)
            plt.axis('off')
            plt.title('Predicted Mask')

            plt.subplot(1, 3, 3)
            true_mask = mask_to_image(np.asarray(true_mask), mask_values)
            true_mask.putpalette([
                0, 0, 0,        # black background
                255, 255, 255,  # class 1
                0, 0, 255,      # class 2
                0, 255, 0,      # class 3
            ])
            plt.imshow(true_mask)
            plt.axis('off')
            plt.title('True Mask')

            # Save before show(): show() blocks and the figure is cleared when the window closes.
            plt.savefig('res_comparison.png')
            plt.show()

        if visualize:
            logging.info(f'Visualizing results for image {filename}, close to continue...')
            plot_img_and_mask(img, mask)


def solve(model_path, input_file, output_file=None, visualize=False, no_save=False,
          mask_threshold=0.5, scale=0.5, bilinear=False, num_classes=2):
    """Single-image variant of MainSolve; returns the predicted mask as a PIL image (or None if not saved)."""
    logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')

    net = UResnet34(n_channels=3, n_classes=num_classes, bilinear=bilinear)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    logging.info(f'Loading model {model_path}')
    logging.info(f'Using device {device}')

    net.to(device=device)
    state_dict = torch.load(model_path, map_location=device)
    mask_values = state_dict.pop('mask_values', [0, 1])
    net.load_state_dict(state_dict, strict=False)
    logging.info('Model loaded!')

    filename = input_file
    logging.info(f'Predicting image {filename} ...')
    img = Image.open(filename)
    # Currently unused here; kept for parity with MainSolve.
    true_mask = Image.open(filename.replace('image_1', 'mask'))

    mask = predict_img(net=net,
                       full_img=img,
                       scale_factor=scale,
                       out_threshold=mask_threshold,
                       device=device)

    result = None
    if not no_save:
        out_filename = output_file
        result = mask_to_image(mask, mask_values)
        result.putpalette([
            0, 0, 0,        # black background
            255, 255, 255,  # class 1
            0, 0, 255,      # class 2
            0, 255, 0,      # class 3
        ])
        result.save(out_filename)
        logging.info(f'Mask saved to {out_filename}')

    return result


if __name__ == '__main__':
    args = get_args()
    MainSolve(model_path=args.model,
              input_files=args.input,
              output_files=args.output,
              visualize=args.viz,
              no_save=args.no_save,
              mask_threshold=args.mask_threshold,
              scale=args.scale,
              bilinear=args.bilinear,
              num_classes=args.classes)
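

# --- Example usage ---------------------------------------------------------
# A minimal sketch of how this script can be invoked; the data paths below are
# illustrative placeholders, not files shipped with the repo. Note that the
# ground-truth mask is looked up by replacing 'image_1' with 'mask' in the
# input path, so inputs are expected to live under an 'image_1' directory.
#
#   python predict.py \
#       --model checkpoints/checkpoint_epoch24.pth \
#       --input data/test/image_1/0001.png \
#       --output data/test/0001_OUT.png \
#       --classes 2 --scale 0.5
#
# Programmatic use (assuming this file is saved as predict.py):
#
#   from predict import solve
#   pred = solve('checkpoints/checkpoint_epoch24.pth',
#                'data/test/image_1/0001.png',
#                output_file='data/test/0001_OUT.png')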