import argparse
import os
import random

import numpy as np
import pandas as pd
import pydicom
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from PIL import Image
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import confusion_matrix as sklearn_cm
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm


# Device configuration
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f'Using device: {device}')


# Early stopping on validation balanced accuracy (a metric to maximize)
class EarlyStopping:
    def __init__(self, patience=20, verbose=False, delta=0.1):
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.delta = delta
        self.best_val_ba = 0.0

    def __call__(self, val_ba, model):
        score = val_ba
        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_ba, model)
        elif score < self.best_score + self.delta:
            # No improvement beyond delta: count towards patience
            self.counter += 1
            if self.verbose:
                print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_ba, model)
            self.counter = 0

    def save_checkpoint(self, val_ba, model):
        if self.verbose:
            print(f'Validation BA increased ({self.best_val_ba:.2f} --> {val_ba:.2f}). Saving model ...')
        torch.save(model.state_dict(), 'swav_best_model.pth')
        self.best_val_ba = val_ba
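
# Usage sketch (illustrative values only): the stopper is driven by balanced
# accuracy, so higher is better, and improvements smaller than `delta` still
# advance the patience counter:
#   stopper = EarlyStopping(patience=20, verbose=True, delta=0.1)
#   stopper(72.5, model)   # first call: checkpoint saved
#   stopper(72.55, model)  # within delta of best: counter -> 1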


# SwAV model: ResNet backbone + projection head + prototype layer
class SwaV(nn.Module):
    def __init__(self, backbone, n_prototypes=400):
        super().__init__()
        self.backbone = backbone
        self.projection_head = nn.Sequential(
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 128)
        )
        self.prototypes = nn.Linear(128, n_prototypes, bias=False)

    def forward(self, x):
        features = self.backbone(x)
        features = features.flatten(start_dim=1)
        z = self.projection_head(features)
        z = F.normalize(z, dim=1, p=2)  # L2-normalize embeddings
        p = self.prototypes(z)  # scores against each prototype
        return p
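
# Shape sanity check (illustrative, not executed during training): with a
# ResNet-18 backbone the flattened feature is 512-d, so for a batch of four
# 3x128x128 crops the model returns a (4, n_prototypes) score matrix:
#   scores = model(torch.randn(4, 3, 128, 128))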


# SwAV loss (manual implementation): Sinkhorn-Knopp code assignment plus the
# swapped-prediction objective. Assumes view-major batches: the two global
# crops stacked as (2B, K) and the six local crops as (6B, K), so that row i
# of every per-view chunk corresponds to the same image.
class SwaVLoss(nn.Module):
    def __init__(self, temperature=0.1, sinkhorn_iterations=3):
        super().__init__()
        self.temperature = temperature
        self.sinkhorn_iterations = sinkhorn_iterations

    @torch.no_grad()
    def sinkhorn(self, scores):
        # scores: (B, K) prototype scores for one view
        Q = torch.exp(scores / self.temperature).T  # (K, B)
        Q = Q / Q.sum()
        K, B = Q.shape
        for _ in range(self.sinkhorn_iterations):
            Q = Q / Q.sum(dim=1, keepdim=True) / K  # equal mass per prototype
            Q = Q / Q.sum(dim=0, keepdim=True) / B  # normalize per sample
        return (Q * B).T  # (B, K); each row is a code summing to 1

    def forward(self, high_res_outputs, low_res_outputs):
        global_scores = high_res_outputs.chunk(2)
        all_scores = list(global_scores) + list(low_res_outputs.chunk(6))
        loss, n_terms = 0.0, 0
        # Swapped prediction: the codes of each global view supervise the
        # softmax predictions of every other view.
        for i, scores in enumerate(global_scores):
            q = self.sinkhorn(scores.detach())
            for j, p in enumerate(all_scores):
                if j == i:
                    continue
                log_p = F.log_softmax(p / self.temperature, dim=1)
                loss = loss - torch.mean(torch.sum(q * log_p, dim=1))
                n_terms += 1
        return loss / n_terms
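
# Illustration of the Sinkhorn step (hypothetical numbers): a few iterations
# spread a batch roughly evenly across prototypes while keeping each sample's
# code a valid distribution:
#   q = SwaVLoss().sinkhorn(torch.randn(8, 400))
#   q.sum(dim=1)  # ~1.0 for every sample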


# Balanced accuracy: mean per-class recall, in percent
def calculate_balanced_accuracy(output, target):
    cm = sklearn_cm(target, output)  # rows: true classes, cols: predictions
    n_class = cm.shape[0]
    recalls = []
    for i in range(n_class):
        class_total = np.sum(cm[i])
        if class_total == 0:
            recalls.append(0.0)
        else:
            recalls.append(cm[i, i] / class_total)
    return np.mean(np.array(recalls)) * 100
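
# Sanity check (optional): when every class appears in `target`, this should
# agree with sklearn's built-in metric up to the percentage scaling:
#   from sklearn.metrics import balanced_accuracy_score
#   balanced_accuracy_score(target, output) * 100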


# Load a DICOM file and convert it to an RGB PIL image
def load_dicom_as_pil(path):
    dicom = pydicom.dcmread(path)
    img_array = dicom.pixel_array.astype(np.float32)
    # Apply a fixed intensity window (center 40, width 400), then scale to 8-bit
    center, width = 40, 400
    img_array = np.clip(img_array, center - width // 2, center + width // 2)
    img_array = (img_array - (center - width // 2)) / width
    img_array = np.clip(img_array, 0.0, 1.0)
    img_array = (img_array * 255).astype(np.uint8)
    return Image.fromarray(img_array).convert('RGB')
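
# Note: pixel values are used exactly as stored. If the DICOMs carry
# RescaleSlope / RescaleIntercept tags, those would normally be applied
# before windowing; that step is omitted here.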


# Multi-crop augmentation for SwAV (fixed to call transforms correctly)
class SwAVTransform:
    def __init__(self, crop_size=128):
        self.crop_size = crop_size
        # Global crops keep at least 14% of the image; local crops are smaller
        self.global_crop = transforms.RandomResizedCrop(crop_size, scale=(0.14, 1.0))
        self.local_crop = transforms.Compose([
            transforms.RandomResizedCrop(crop_size // 2, scale=(0.05, 0.14)),
            transforms.Resize(crop_size),  # resize back so all views share one resolution
        ])
        # Augmentations applied after cropping
        self.post_transform = transforms.Compose([
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
            transforms.RandomGrayscale(p=0.2),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])

    def __call__(self, img):
        # 2 global (high-resolution) views followed by 6 local (low-resolution) views
        views = [self.post_transform(self.global_crop(img)) for _ in range(2)]
        views += [self.post_transform(self.local_crop(img)) for _ in range(6)]
        return views
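
# Each call returns a list of 8 tensors: views[0:2] are the global crops and
# views[2:8] the local crops. The collate function below relies on this order.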


# Unlabeled dataset for self-supervised pretraining
class RSNAUnlabeledDataset(Dataset):
    def __init__(self, data_dir, transform=None):
        self.data_dir = data_dir
        self.transform = transform
        self.image_files = [f for f in os.listdir(data_dir) if f.endswith('.dcm')]

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_path = os.path.join(self.data_dir, self.image_files[idx])
        img = load_dicom_as_pil(img_path)
        if self.transform:
            img = self.transform(img)  # returns the list of 8 views
        return img, 0, idx  # dummy label; SwAV pretraining is label-free


# Labeled dataset for linear evaluation
class RSNALabeledDataset(Dataset):
    def __init__(self, data_dir, label_csv, transform=None):
        self.data_dir = data_dir
        self.transform = transform
        self.labels_df = pd.read_csv(label_csv)
        self.labels_df['patientId'] = self.labels_df['patientId'].astype(str)
        self.image_ids = self.labels_df['patientId'].tolist()
        self.labels = self.labels_df['Target'].tolist()

    def __len__(self):
        return len(self.image_ids)

    def __getitem__(self, idx):
        img_id = self.image_ids[idx]
        img_path = os.path.join(self.data_dir, f"{img_id}.dcm")
        img = load_dicom_as_pil(img_path)
        if self.transform:
            img = self.transform(img)
        return img, self.labels[idx], idx


# Build the four dataloaders: unlabeled pretraining, labeled train, val, test
def get_rsna_dataloaders(args):
    data_dir = os.path.join(args.root, args.dataset_name, "stage_2_train_images")

    # Unlabeled multi-crop augmentation
    swav_transform = SwAVTransform(crop_size=args.crop_size)
    unlabeled_dataset = RSNAUnlabeledDataset(data_dir=data_dir, transform=swav_transform)

    # Deterministic transform for the labeled splits
    dataset_transform = transforms.Compose([
        transforms.Resize(args.crop_size),
        transforms.CenterCrop(args.crop_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])

    label_csv_path = os.path.join(args.root, args.dataset_name, "stage_2_train_labels.csv")
    full_labels = pd.read_csv(label_csv_path)
    full_labels['patientId'] = full_labels['patientId'].astype(str)
    # The stage-2 label file has one row per bounding box, so positive patients
    # can appear several times; keep one row per patient.
    full_labels = full_labels.drop_duplicates(subset='patientId')
    available_ids = {os.path.splitext(f)[0] for f in os.listdir(data_dir)}
    full_labels = full_labels[full_labels['patientId'].isin(available_ids)]

    train_labels = full_labels.sample(n=min(400, len(full_labels)), random_state=args.seed)
    remaining_labels = full_labels.drop(train_labels.index)
    val_labels = remaining_labels.sample(n=min(400, len(remaining_labels)), random_state=args.seed)
    test_labels = remaining_labels.drop(val_labels.index)

    os.makedirs(args.save_dir, exist_ok=True)
    train_label_path = os.path.join(args.save_dir, f"train_labels_seed{args.seed}.csv")
    val_label_path = os.path.join(args.save_dir, f"val_labels_seed{args.seed}.csv")
    test_label_path = os.path.join(args.save_dir, f"test_labels_seed{args.seed}.csv")

    train_labels.to_csv(train_label_path, index=False)
    val_labels.to_csv(val_label_path, index=False)
    test_labels.to_csv(test_label_path, index=False)

    labeled_train_dataset = RSNALabeledDataset(data_dir, train_label_path, transform=dataset_transform)
    val_dataset = RSNALabeledDataset(data_dir, val_label_path, transform=dataset_transform)
    test_dataset = RSNALabeledDataset(data_dir, test_label_path, transform=dataset_transform)

    # Collate the 8 views per sample into view-major batches: all samples'
    # view 0, then all samples' view 1, and so on. This keeps row i of every
    # per-view chunk aligned with the same image, which SwaVLoss relies on.
    def swav_collate_fn(batch):
        n_views = len(batch[0][0])
        views = [torch.stack([sample[0][v] for sample in batch]) for v in range(n_views)]
        high_res_views = torch.cat(views[:2])  # (2B, C, H, W) global crops
        low_res_views = torch.cat(views[2:])   # (6B, C, H, W) local crops
        B = len(batch)
        return (high_res_views, low_res_views), torch.zeros(B), torch.zeros(B)

    train_loader = DataLoader(
        unlabeled_dataset, batch_size=args.batch_size, shuffle=True, collate_fn=swav_collate_fn,
        num_workers=args.num_workers, pin_memory=True, drop_last=True
    )
    labeled_train_loader = DataLoader(
        labeled_train_dataset, batch_size=args.batch_size, shuffle=True,
        num_workers=args.num_workers, pin_memory=True
    )
    val_loader = DataLoader(
        val_dataset, batch_size=args.batch_size, shuffle=False,
        num_workers=args.num_workers, pin_memory=True
    )
    test_loader = DataLoader(
        test_dataset, batch_size=args.batch_size, shuffle=False,
        num_workers=args.num_workers, pin_memory=True
    )

    print("Dataset split check:")
    print(f"- Unlabeled train: {len(unlabeled_dataset)} samples")
    print(f"- Labeled train: {len(labeled_train_dataset)} samples")
    print(f"- Validation: {len(val_dataset)} samples")
    print(f"- Test: {len(test_dataset)} samples")

    return train_loader, labeled_train_loader, val_loader, test_loader


# Pretraining loop (one epoch)
def train(args, net, data_loader, train_optimizer, scheduler, criterion, epoch):
    net.train()
    total_loss = 0.0
    total_num = 0
    train_bar = tqdm(data_loader, desc=f'Train Epoch [{epoch}/{args.epochs}]')

    for (high_res_views, low_res_views), _, _ in train_bar:
        high_res_views = high_res_views.to(args.device)
        low_res_views = low_res_views.to(args.device)

        # As in the original SwAV recipe, keep prototype vectors on the unit
        # sphere so prototype scores behave like cosine similarities.
        with torch.no_grad():
            w = net.prototypes.weight.data.clone()
            net.prototypes.weight.data.copy_(F.normalize(w, dim=1, p=2))

        train_optimizer.zero_grad()
        high_res_outputs = net(high_res_views)
        low_res_outputs = net(low_res_views)

        loss = criterion(high_res_outputs, low_res_outputs)

        loss.backward()
        train_optimizer.step()
        scheduler.step()  # cosine schedule stepped per batch

        total_num += len(high_res_views)
        total_loss += loss.item() * len(high_res_views)
        train_bar.set_description(f'Train Epoch [{epoch}/{args.epochs}], Loss: {total_loss / total_num:.4f}')

    return total_loss / total_num
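
# Note: the official SwAV implementation additionally keeps a queue of past
# features to stabilize the Sinkhorn assignment when batches are small
# (batch_size defaults to 8 here). That queue is not implemented in this script.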


# Linear evaluation: fit a logistic regression on frozen backbone features
def fine_tune_classifier(net, labeled_loader, val_loader, test_loader, args):
    net.eval()

    class FeatureExtractor(nn.Module):
        def __init__(self, backbone):
            super().__init__()
            self.backbone = backbone

        def forward(self, x):
            features = self.backbone(x)
            return features.flatten(start_dim=1)

    feature_extractor = FeatureExtractor(net.backbone).to(args.device)

    def extract_features(dataloader):
        features, labels = [], []
        with torch.no_grad():
            for data, target, _ in dataloader:
                data = data.to(args.device)
                feat = feature_extractor(data)
                features.append(feat.cpu().numpy())
                labels.append(target.numpy())
        return np.vstack(features), np.hstack(labels)

    train_features, train_labels = extract_features(labeled_loader)
    val_features, val_labels = extract_features(val_loader)
    test_features, test_labels = extract_features(test_loader)

    # Random search over the inverse regularization strength C in [1e-3, 1]
    best_reg, best_val_ba, best_test_ba = 0.01, 0.0, 0.0
    for _ in range(10):
        reg = np.power(10, random.uniform(-3, 0))
        clf = LogisticRegression(random_state=args.seed, C=reg, max_iter=1000, solver='saga')
        clf.fit(train_features, train_labels)

        val_pred = clf.predict(val_features)
        val_ba = calculate_balanced_accuracy(val_pred, val_labels)

        if val_ba > best_val_ba:
            best_val_ba = val_ba
            best_reg = reg
            # Report test BA only for the model selected on validation
            test_pred = clf.predict(test_features)
            best_test_ba = calculate_balanced_accuracy(test_pred, test_labels)

    return best_val_ba, best_test_ba, best_reg
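
# Optional tweak (an assumption, not part of the original pipeline): the
# 'saga' solver converges considerably faster on standardized inputs, e.g.
#   from sklearn.preprocessing import StandardScaler
#   scaler = StandardScaler().fit(train_features)
#   clf.fit(scaler.transform(train_features), train_labels)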


# Entry point
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='SwAV for RSNA medical image classification')
    parser.add_argument('--root', type=str, default='/data', help='Path to data directory')
    parser.add_argument('--dataset_name', default='rsna', type=str, help='Dataset name')
    parser.add_argument('--crop_size', default=128, type=int, help='Crop size')
    parser.add_argument('--n_prototypes', default=400, type=int, help='Number of prototypes')
    parser.add_argument('--batch_size', default=8, type=int, help='Batch size')
    parser.add_argument('--lr', default=0.001, type=float, help='Initial learning rate')
    parser.add_argument('--wd', default=1e-6, type=float, help='Weight decay')
    parser.add_argument('--epochs', default=200, type=int, help='Max epochs')
    parser.add_argument('--seed', default=0, type=int, help='Random seed')
    parser.add_argument('--num_workers', default=4, type=int, help='Number of data loading workers')
    parser.add_argument('--save_dir', type=str, default='./results/rsna_seed0', help='Path to save results')

    args = parser.parse_args()
    args.device = device
    os.makedirs(args.save_dir, exist_ok=True)

    # Seed every RNG in use (random.uniform drives the classifier search)
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = True

    print("=" * 50)
    print(f"Loading {args.dataset_name} dataset")
    train_loader, labeled_train_loader, val_loader, test_loader = get_rsna_dataloaders(args)
    print("=" * 50)

    # Initialize the model: ResNet-18 backbone without its final FC layer
    resnet = torchvision.models.resnet18(weights=None)
    backbone = nn.Sequential(*list(resnet.children())[:-1])
    model = SwaV(backbone, n_prototypes=args.n_prototypes)
    model.to(args.device)

    # Loss, optimizer, and per-batch cosine schedule
    criterion = SwaVLoss(temperature=0.1)
    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs * len(train_loader))
    # Patience counts evaluations, which happen every 5 epochs
    early_stopping = EarlyStopping(patience=20, verbose=True)

    results_df = pd.DataFrame(columns=['epoch', 'train_loss', 'val_ba', 'test_ba'])
    best_val_ba = 0.0

    for epoch in range(1, args.epochs + 1):
        train_loss = train(args, model, train_loader, optimizer, scheduler, criterion, epoch)

        if epoch % 5 == 0:
            val_ba, test_ba, _ = fine_tune_classifier(model, labeled_train_loader, val_loader, test_loader, args)
            print(f"Epoch {epoch} Evaluation -> Val BA: {val_ba:.2f}%, Test BA: {test_ba:.2f}%")

            if val_ba > best_val_ba:
                best_val_ba = val_ba
                torch.save(model.state_dict(), os.path.join(args.save_dir, 'best_model.pth'))

            results_df = pd.concat([results_df, pd.DataFrame({
                'epoch': [epoch], 'train_loss': [train_loss], 'val_ba': [val_ba], 'test_ba': [test_ba]
            })], ignore_index=True)
            results_df.to_csv(os.path.join(args.save_dir, 'training_results.csv'), index=False)

            early_stopping(val_ba, model)
            if early_stopping.early_stop:
                print(f"Early stopping at epoch {epoch}")
                break
        else:
            print(f"Epoch {epoch} finished. Train Loss: {train_loss:.4f}")

    print("\nTraining completed!")
    print(f"Best Validation BA: {best_val_ba:.2f}%")