You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

107 lines
3.6 KiB

import torch
from torch.utils.data import DataLoader
import argparse
import os
from data.dataset import LungXrayDataset
from models.autoencoder import Autoencoder
from models.simplecnn import SimpleCNN
from utils import load_config
from torchvision import transforms
from PIL import Image
import warnings
# Suppress warning output: every warning category is ignored, process-wide,
# from the moment this module is imported.
warnings.filterwarnings("ignore")
def parse_args(argv=None):
    """Parse the command-line options for single-image inference.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process command line. ``None`` (the default) keeps the previous
            behaviour of reading ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with the attributes ``autoencoder_path``,
        ``cnn_path``, ``image_path``, ``config`` and ``device``.

    Raises:
        SystemExit: Raised by argparse when a required option is missing or
            ``--device`` is not one of the allowed choices.
    """
    parser = argparse.ArgumentParser(description='COVID-19 X-ray Classification Inference')

    # Model checkpoint paths
    parser.add_argument('--autoencoder_path', type=str, required=True,
                        help='Path to the trained autoencoder model')
    parser.add_argument('--cnn_path', type=str, required=True,
                        help='Path to the trained CNN model')

    # Input data path
    parser.add_argument('--image_path', type=str, required=True,
                        help='Path to the input X-ray image')

    # Config file path
    parser.add_argument('--config', type=str, default='config/config.yaml',
                        help='Path to config file')

    # Device selection
    parser.add_argument('--device', type=str, default='cuda',
                        choices=['cuda', 'cpu'],
                        help='Device to use for inference (cuda or cpu)')

    # argparse falls back to sys.argv[1:] when argv is None.
    return parser.parse_args(argv)
def load_models(args, config, device):
    """Build the autoencoder and the CNN and restore their trained weights.

    Args:
        args: Parsed CLI namespace; ``autoencoder_path`` and ``cnn_path``
            are read from it.
        config: Not used inside this function; kept as part of the signature.
        device: Device the models are moved to and the checkpoints are
            mapped onto.

    Returns:
        Tuple ``(autoencoder, cnn_model)``, both switched to eval mode.
    """
    denoiser = Autoencoder().to(device)
    classifier = SimpleCNN().to(device)

    # NOTE(review): depending on the installed torch version, torch.load may
    # unpickle arbitrary objects -- confirm checkpoints come from a trusted source.
    denoiser_weights = torch.load(args.autoencoder_path, map_location=device)
    denoiser.load_state_dict(denoiser_weights)
    classifier_weights = torch.load(args.cnn_path, map_location=device)
    classifier.load_state_dict(classifier_weights)

    # Inference only: put both networks into evaluation mode.
    for net in (denoiser, classifier):
        net.eval()

    return denoiser, classifier
def preprocess_image(image_path, config):
    """Load an image from disk and turn it into a batched input tensor.

    The image is converted to single-channel grayscale, resized, converted to
    a tensor and, when enabled in the config, normalized.

    Args:
        image_path: Path of the image file to open.
        config: Mapping whose ``config['data']['preprocess']`` section
            provides ``resize_size`` and, when ``normalize`` is truthy,
            ``mean`` and ``std``.

    Returns:
        The transformed image tensor with an extra leading batch dimension
        of size 1.

    Raises:
        OSError: If the image file cannot be opened or decoded.
        KeyError: If a required config key is missing.
    """
    # Open inside a context manager so the file handle is released as soon as
    # the pixels are decoded; convert() returns an independent in-memory image.
    with Image.open(image_path) as source:
        img = source.convert('L')  # convert to grayscale

    preprocess_config = config['data']['preprocess']

    transform_list = [
        transforms.Resize(preprocess_config['resize_size']),
        transforms.ToTensor(),
    ]

    # Normalization is optional and skipped when the key is absent or falsy.
    if preprocess_config.get('normalize', False):
        transform_list.append(transforms.Normalize(
            mean=preprocess_config['mean'],
            std=preprocess_config['std'],
        ))

    transform = transforms.Compose(transform_list)
    img_tensor = transform(img).unsqueeze(0)  # add the batch dimension
    return img_tensor
def main():
    """Run single-image inference and print the prediction report.

    Parses the CLI options, loads the config and both models, preprocesses
    the input image, passes it through the autoencoder and then the CNN, and
    prints the predicted class plus the per-class probabilities in percent.
    """
    cli = parse_args()

    # The requested device is used only when CUDA is available; otherwise
    # inference runs on the CPU regardless of --device.
    device_name = cli.device if torch.cuda.is_available() else 'cpu'
    device = torch.device(device_name)
    settings = load_config(cli.config)

    # Load the models
    denoiser, classifier = load_models(cli, settings, device)

    # Preprocess the image
    batch = preprocess_image(cli.image_path, settings).to(device)

    with torch.no_grad():
        # Pass the input through the autoencoder first (denoising step)
        cleaned = denoiser(batch)
        # Classify the autoencoder output with the CNN
        logits = classifier(cleaned)
        probs = torch.softmax(logits, dim=1)
        winner = torch.argmax(probs, dim=1).item()

    # Class labels; per the original author's note these must match the
    # training dataset (example classes).
    labels = ['Covid', 'Normal', 'Viral Pneumonia']

    # Probabilities of the single batch item, scaled to percentages
    percent = probs.cpu().numpy()[0] * 100

    # Formatted report
    rule = "--------------------------------------"
    print("Prediction Result:")
    print(rule)
    print(f"Predicted Class: {labels[winner]}")
    print(f"Probabilities: Covid: {percent[0]:.2f}%, Normal: {percent[1]:.2f}%, Viral Pneumonia: {percent[2]:.2f}%")
    print(rule)
# Run inference only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()