import argparse
import os
import numpy as np
import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
import torchvision
import torchvision.transforms as transforms
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import confusion_matrix as sklearn_cm
import pandas as pd
import pydicom
from PIL import Image
import random
# Device configuration
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print('Using device: {}'.format(device))
# Early stopping on validation balanced accuracy
class EarlyStopping:
    def __init__(self, patience=20, verbose=False, delta=0.1):
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.delta = delta
        self.val_ba_min = 0.0

    def __call__(self, val_ba, model):
        score = val_ba
        if self.best_score is None:
            self.best_score = score
            self.val_ba_min = score
            self.save_checkpoint(val_ba, model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            if self.verbose:
                print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_ba, model)
            self.counter = 0

    def save_checkpoint(self, val_ba, model):
        # Save the model weights whenever the validation BA improves.
        if self.verbose:
            print(f'Validation BA increased ({self.val_ba_min:.2f} --> {val_ba:.2f}). Saving model ...')
        torch.save(model.state_dict(), 'swav_best_model.pth')
        self.val_ba_min = val_ba
# SwAV model: ResNet backbone + projection head + prototype layer
class SwaV(nn.Module):
    def __init__(self, backbone, n_prototypes=400):
        super().__init__()
        self.backbone = backbone
        # Project the 512-d ResNet-18 features into a 128-d embedding space.
        self.projection_head = nn.Sequential(
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 128)
        )
        # One prototype score per cluster, no bias.
        self.prototypes = nn.Linear(128, n_prototypes, bias=False)

    def forward(self, x):
        features = self.backbone(x)
        features = features.flatten(start_dim=1)
        z = self.projection_head(features)
        z = F.normalize(z, dim=1, p=2)
        p = self.prototypes(z)
        return p
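# Optional sketch, not called by the training loop below: the reference SwAV
# recipe L2-normalizes the prototype weights after every optimizer step so that
# prototype scores stay on a comparable scale. Included here only as an
# illustration of that step; whether to apply it is left to the reader.
def normalize_prototypes(model):
    with torch.no_grad():
        w = model.prototypes.weight.data.clone()
        w = F.normalize(w, dim=1, p=2)
        model.prototypes.weight.copy_(w)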
# Manual implementation of the SwAV swapped-prediction loss
class SwaVLoss(nn.Module):
    def __init__(self, temperature=0.1, sinkhorn_iterations=3):
        super().__init__()
        self.temperature = temperature
        self.sinkhorn_iterations = sinkhorn_iterations

    @torch.no_grad()
    def sinkhorn(self, scores):
        # Sinkhorn-Knopp: turn prototype scores of shape (B, K) into soft cluster
        # assignments whose rows sum to 1 and whose columns are roughly balanced.
        Q = torch.exp(scores / self.temperature).T  # (K, B)
        Q = Q / Q.sum()
        K, B = Q.shape
        for _ in range(self.sinkhorn_iterations):
            Q = Q / Q.sum(1, keepdim=True)  # balance the prototypes
            Q = Q / Q.sum(0, keepdim=True)  # normalize each sample's assignment
        return Q.T  # (B, K), each row sums to 1

    def forward(self, high_res_outputs, low_res_outputs):
        # high_res_outputs: prototype scores of the 2 global views, shape (2*B, K)
        # low_res_outputs: prototype scores of the 6 local views, shape (6*B, K)
        B = high_res_outputs.shape[0] // 2
        global_views = [high_res_outputs[i * B:(i + 1) * B] for i in range(2)]
        local_views = [low_res_outputs[i * B:(i + 1) * B] for i in range(6)]
        loss, n_terms = 0.0, 0
        for i, anchor in enumerate(global_views):
            # Codes are computed from a global view only, without gradients.
            q = self.sinkhorn(anchor.detach())
            # Every other view has to predict these codes.
            for j, view in enumerate(global_views + local_views):
                if j == i:
                    continue
                log_p = F.log_softmax(view / self.temperature, dim=1)
                loss = loss - torch.mean(torch.sum(q * log_p, dim=1))
                n_terms += 1
        return loss / n_terms
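# Shape sketch (illustration only): with batch_size B, the SwAV collate function
# defined below yields high-res views of shape (2*B, 3, crop_size, crop_size) and
# low-res views of shape (6*B, 3, crop_size, crop_size), grouped view by view;
# after the SwaV head they become prototype-score matrices of shape
# (2*B, n_prototypes) and (6*B, n_prototypes), which is the layout that
# SwaVLoss.forward expects.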
# Balanced accuracy (in percent), computed as the mean per-class recall
def calculate_balanced_accuracy(output, target):
    cm = sklearn_cm(target, output)
    n_class = cm.shape[0]
    recalls = []
    for i in range(n_class):
        class_total = np.sum(cm[i])
        if class_total == 0:
            recalls.append(0.0)
        else:
            recall = cm[i, i] / class_total
            recalls.append(recall)
    balanced_accuracy = np.mean(np.array(recalls))
    return balanced_accuracy * 100
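# Small worked example (illustration only): for target = [0, 0, 1, 1] and
# output = [0, 1, 1, 1] the confusion matrix is [[1, 1], [0, 2]], the per-class
# recalls are 0.5 and 1.0, and calculate_balanced_accuracy returns 75.0.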
# Shared DICOM-loading helper: read a DICOM file and window it into an 8-bit RGB image
def load_dicom_as_pil(path):
    dicom = pydicom.dcmread(path)
    img_array = dicom.pixel_array
    img_array = img_array.astype(np.float32)
    # Fixed intensity window: center 40, width 400, i.e. clip to [-160, 240]
    # before rescaling to [0, 255].
    center, width = 40, 400
    img_array = np.clip(img_array, center - width // 2, center + width // 2)
    img_array = (img_array - (center - width // 2)) / width
    img_array = np.clip(img_array, 0.0, 1.0)
    img_array = (img_array * 255).astype(np.uint8)
    img = Image.fromarray(img_array).convert('RGB')
    return img
# Custom SwAV multi-crop view augmentation (fixes the transform call pattern)
class SwAVTransform:
    def __init__(self, crop_size=128):
        self.crop_size = crop_size
        # Augmentations applied after cropping (the crops themselves are taken in __call__)
        self.post_transform = transforms.Compose([
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
            transforms.RandomGrayscale(p=0.2),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5], std=[0.5])
        ])

    def __call__(self, img):
        views = []
        # 2 high-resolution (global) views, scale >= 0.14
        for _ in range(2):
            crop = transforms.RandomResizedCrop(self.crop_size, scale=(0.14, 1.0))(img)
            views.append(self.post_transform(crop))
        # 6 low-resolution (local) views, scale < 0.14, resized back to crop_size
        for _ in range(6):
            crop = transforms.RandomResizedCrop(self.crop_size // 2, scale=(0.05, 0.14))(img)
            crop = transforms.Resize(self.crop_size)(crop)
            views.append(self.post_transform(crop))
        return views
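# Usage note (illustration only): calling SwAVTransform on one PIL image returns
# a list of 8 tensors, views[0:2] from global crops and views[2:8] from local
# crops resized back to crop_size, so every tensor has shape (3, crop_size, crop_size).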
# Unlabeled dataset: every DICOM file in the directory, no targets
class RSNAUnlabeledDataset(Dataset):
    def __init__(self, data_dir, transform=None):
        self.data_dir = data_dir
        self.transform = transform
        self.image_files = [f for f in os.listdir(data_dir) if f.endswith('.dcm')]

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_name = self.image_files[idx]
        img_path = os.path.join(self.data_dir, img_name)
        img = load_dicom_as_pil(img_path)
        if self.transform:
            img = self.transform(img)  # returns the list of 8 views
        return img, 0, idx
# Labeled dataset driven by a CSV with patientId and Target columns
class RSNALabeledDataset(Dataset):
    def __init__(self, data_dir, label_csv, transform=None):
        self.data_dir = data_dir
        self.transform = transform
        self.labels_df = pd.read_csv(label_csv)
        self.labels_df['patientId'] = self.labels_df['patientId'].astype(str)
        self.image_ids = self.labels_df['patientId'].tolist()
        self.labels = self.labels_df['Target'].tolist()

    def __len__(self):
        return len(self.image_ids)

    def __getitem__(self, idx):
        img_id = self.image_ids[idx]
        img_path = os.path.join(self.data_dir, f"{img_id}.dcm")
        img = load_dicom_as_pil(img_path)
        if self.transform:
            img = self.transform(img)
        label = self.labels[idx]
        return img, label, idx
# Dataloader construction
def get_rsna_dataloaders(args):
    data_dir = os.path.join(args.root, args.dataset_name, "stage_2_train_images")
    # Multi-crop augmentation for the unlabeled pretraining data
    swav_transform = SwAVTransform(crop_size=args.crop_size)
    unlabeled_dataset = RSNAUnlabeledDataset(data_dir=data_dir, transform=swav_transform)
    # Deterministic transform for the labeled splits
    dataset_transform = transforms.Compose([
        transforms.Resize(args.crop_size),
        transforms.CenterCrop(args.crop_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5])
    ])
    label_csv_path = os.path.join(args.root, args.dataset_name, "stage_2_train_labels.csv")
    full_labels = pd.read_csv(label_csv_path)
    full_labels['patientId'] = full_labels['patientId'].astype(str)
    # The RSNA label file can contain several rows per patient (one per bounding
    # box); keep a single row per patient so the splits do not overlap.
    full_labels = full_labels.drop_duplicates(subset='patientId')
    available_ids = {os.path.splitext(f)[0] for f in os.listdir(data_dir)}
    full_labels = full_labels[full_labels['patientId'].isin(available_ids)]
    train_labels = full_labels.sample(n=min(400, len(full_labels)), random_state=args.seed)
    remaining_labels = full_labels.drop(train_labels.index)
    val_labels = remaining_labels.sample(n=min(400, len(remaining_labels)), random_state=args.seed)
    test_labels = remaining_labels.drop(val_labels.index)
    os.makedirs(args.save_dir, exist_ok=True)
    train_label_path = os.path.join(args.save_dir, f"train_labels_seed{args.seed}.csv")
    val_label_path = os.path.join(args.save_dir, f"val_labels_seed{args.seed}.csv")
    test_label_path = os.path.join(args.save_dir, f"test_labels_seed{args.seed}.csv")
    train_labels.to_csv(train_label_path, index=False)
    val_labels.to_csv(val_label_path, index=False)
    test_labels.to_csv(test_label_path, index=False)
    labeled_train_dataset = RSNALabeledDataset(data_dir, train_label_path, transform=dataset_transform)
    val_dataset = RSNALabeledDataset(data_dir, val_label_path, transform=dataset_transform)
    test_dataset = RSNALabeledDataset(data_dir, test_label_path, transform=dataset_transform)

    # Collate function that regroups the per-sample view lists: view k of every
    # sample forms one contiguous chunk of B images, the first 2 chunks being the
    # high-resolution views and the remaining 6 chunks the low-resolution views.
    def swav_collate_fn(batch):
        B = len(batch)
        views_by_index = list(zip(*[imgs for imgs, _, _ in batch]))  # 8 tuples of length B
        high_res_views = torch.cat([torch.stack(v) for v in views_by_index[:2]], dim=0)
        low_res_views = torch.cat([torch.stack(v) for v in views_by_index[2:]], dim=0)
        return (high_res_views, low_res_views), torch.zeros(B), torch.zeros(B)

    train_loader = DataLoader(
        unlabeled_dataset, batch_size=args.batch_size, shuffle=True, collate_fn=swav_collate_fn,
        num_workers=args.num_workers, pin_memory=True, drop_last=True
    )
    labeled_train_loader = DataLoader(
        labeled_train_dataset, batch_size=args.batch_size, shuffle=True,
        num_workers=args.num_workers, pin_memory=True
    )
    val_loader = DataLoader(
        val_dataset, batch_size=args.batch_size, shuffle=False,
        num_workers=args.num_workers, pin_memory=True
    )
    test_loader = DataLoader(
        test_dataset, batch_size=args.batch_size, shuffle=False,
        num_workers=args.num_workers, pin_memory=True
    )
    print("Dataset split check:")
    print(f"- Unlabeled train: {len(unlabeled_dataset)} samples")
    print(f"- Labeled train: {len(labeled_train_dataset)} samples")
    print(f"- Validation: {len(val_dataset)} samples")
    print(f"- Test: {len(test_dataset)} samples")
    return train_loader, labeled_train_loader, val_loader, test_loader
# SwAV pretraining loop for one epoch
def train(args, net, data_loader, train_optimizer, scheduler, criterion, epoch):
    net.train()
    total_loss = 0.0
    total_num = 0
    train_bar = tqdm(data_loader, desc=f'Train Epoch [{epoch}/{args.epochs}]')
    for (high_res_views, low_res_views), _, _ in train_bar:
        high_res_views = high_res_views.to(args.device)
        low_res_views = low_res_views.to(args.device)
        train_optimizer.zero_grad()
        high_res_outputs = net(high_res_views)
        low_res_outputs = net(low_res_views)
        loss = criterion(high_res_outputs, low_res_outputs)
        loss.backward()
        train_optimizer.step()
        scheduler.step()
        total_num += len(high_res_views)
        total_loss += loss.item() * len(high_res_views)
        train_bar.set_description(f'Train Epoch [{epoch}/{args.epochs}], Loss: {total_loss / total_num:.4f}')
    return total_loss / total_num
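# Note: the cosine learning-rate scheduler is stepped once per batch inside
# train(), so its T_max (set in the main block below) counts total training
# iterations, i.e. epochs * len(train_loader).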
# Linear-probe evaluation: fit a logistic regression on frozen backbone features
def fine_tune_classifier(net, labeled_loader, val_loader, test_loader, args):
    net.eval()

    class FeatureExtractor(nn.Module):
        def __init__(self, backbone):
            super().__init__()
            self.backbone = backbone

        def forward(self, x):
            features = self.backbone(x)
            features = features.flatten(start_dim=1)
            return features

    feature_extractor = FeatureExtractor(net.backbone).to(args.device)

    def extract_features(dataloader):
        features, labels = [], []
        with torch.no_grad():
            for data, target, _ in dataloader:
                data = data.to(args.device)
                feat = feature_extractor(data)
                features.append(feat.cpu().numpy())
                labels.append(target.numpy())
        return np.vstack(features), np.hstack(labels)

    train_features, train_labels = extract_features(labeled_loader)
    val_features, val_labels = extract_features(val_loader)
    test_features, test_labels = extract_features(test_loader)
    # Random search over the regularization strength C on a log scale.
    best_reg, best_val_ba, best_test_ba = 0.01, 0.0, 0.0
    for _ in range(10):
        reg = np.power(10, random.uniform(-3, 0))
        clf = LogisticRegression(random_state=args.seed, C=reg, max_iter=1000, solver='saga')
        clf.fit(train_features, train_labels)
        val_pred = clf.predict(val_features)
        val_ba = calculate_balanced_accuracy(val_pred, val_labels)
        if val_ba > best_val_ba:
            best_val_ba = val_ba
            best_reg = reg
            test_pred = clf.predict(test_features)
            best_test_ba = calculate_balanced_accuracy(test_pred, test_labels)
    return best_val_ba, best_test_ba, best_reg
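# Optional variation (not used above): the saga solver tends to converge faster
# when the features are standardized first, e.g. with a scikit-learn pipeline
# such as make_pipeline(StandardScaler(), LogisticRegression(C=reg, max_iter=1000));
# this script fits LogisticRegression on the raw backbone features instead.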
# 主函数
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='SwaV for RSNA Medical Image Classification')
parser.add_argument('--root', type=str, default='/data', help='Path to data directory')
parser.add_argument('--dataset_name', default='rsna', type=str, help='Dataset name')
parser.add_argument('--crop_size', default=128, type=int, help='Crop size')
parser.add_argument('--n_prototypes', default=400, type=int, help='Number of prototypes')
parser.add_argument('--batch_size', default=8, type=int, help='Batch size')
parser.add_argument('--lr', default=0.001, type=float, help='Initial learning rate')
parser.add_argument('--wd', default=1e-6, type=float, help='Weight decay')
parser.add_argument('--epochs', default=200, type=int, help='Max epochs')
parser.add_argument('--seed', default=0, type=int, help='Random seed')
parser.add_argument('--num_workers', default=4, type=int, help='Number of data loading workers')
parser.add_argument('--save_dir', type=str, default='./results/rsna_seed0', help='Path to save results')
args = parser.parse_args()
args.device = device
os.makedirs(args.save_dir, exist_ok=True)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.cuda.manual_seed(args.seed)
torch.backends.cudnn.deterministic = True
print("=" * 50)
print(f"Loading {args.dataset_name} dataset")
train_loader, labeled_train_loader, val_loader, test_loader = get_rsna_dataloaders(args)
print("=" * 50)
# 初始化模型
resnet = torchvision.models.resnet18(weights=None)
backbone = nn.Sequential(*list(resnet.children())[:-1])
model = SwaV(backbone, n_prototypes=args.n_prototypes)
model.to(args.device)
# 初始化损失函数和优化器
criterion = SwaVLoss(temperature=0.1)
optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs * len(train_loader))
early_stopping = EarlyStopping(patience=20, verbose=True)
results_df = pd.DataFrame(columns=['epoch', 'train_loss', 'val_ba', 'test_ba'])
best_val_ba = 0.0
for epoch in range(1, args.epochs + 1):
train_loss = train(args, model, train_loader, optimizer, scheduler, criterion, epoch)
if epoch % 5 == 0:
val_ba, test_ba, _ = fine_tune_classifier(model, labeled_train_loader, val_loader, test_loader, args)
print(f"Epoch {epoch} Evaluation -> Val BA: {val_ba:.2f}%, Test BA: {test_ba:.2f}%")
if val_ba > best_val_ba:
best_val_ba = val_ba
torch.save(model.state_dict(), os.path.join(args.save_dir, 'best_model.pth'))
results_df = pd.concat([results_df, pd.DataFrame({
'epoch': [epoch], 'train_loss': [train_loss], 'val_ba': [val_ba], 'test_ba': [test_ba]
})], ignore_index=True)
results_df.to_csv(os.path.join(args.save_dir, 'training_results.csv'), index=False)
early_stopping(val_ba, model)
if early_stopping.early_stop:
print(f"Early stopping at epoch {epoch}")
break
else:
print(f"Epoch {epoch} finished. Train Loss: {train_loss:.4f}")
print("\nTraining completed!")
print(f"Best Validation BA: {best_val_ba:.2f}%")