import torch import torch.nn as nn import torch.nn.functional as F from torchvision import datasets, transforms from torchvision.models import resnet34 from torch.utils.data import random_split from torch.utils.data import WeightedRandomSampler from sklearn.metrics import precision_score, recall_score, f1_score from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay from tqdm import tqdm import time import random import os import shutil from pathlib import Path import matplotlib.pyplot as plt import numpy as np import pprint import timm import pytorch_lightning as pl import logging import albumentations as A from albumentations.pytorch import ToTensorV2 import seaborn as sns h = { "num_epochs": 6, "batch_size": 16, "image_size": 224, "fc1_size": 512, "lr": 0.001, "model": "resnet34_custom", "scheduler": "CosineAnnealingLR10", "balance": True, "early_stopping_patience": float("inf"), "use_best_checkpoint": False } class CustomImageFolder(torch.utils.data.Dataset): def __init__(self, root, transform=None, is_valid_file=None): self.dataset = datasets.ImageFolder(root, is_valid_file=is_valid_file) self.transform = transform self.targets = self.dataset.targets def __getitem__(self, index): image, label = self.dataset[index] if self.transform: image = self.transform(image=np.array(image))["image"] / 255.0 return image, label def __len__(self): return len(self.dataset) class PneumoniaDataModule(pl.LightningDataModule): def __init__(self, h, data_dir): super().__init__() self.h = h self.data_dir = data_dir def setup(self, stage=None): data_transforms_train_alb = A.Compose([ A.Rotate(limit=20), A.HorizontalFlip(p=0.5), A.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1, p=1), A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0, rotate_limit=0, p=0.5), A.Perspective(scale=(0.05, 0.15), keep_size=True, p=0.5), A.Resize(height=h["image_size"], width=h["image_size"]), A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ToTensorV2() ]) data_transforms_val_alb = A.Compose([ A.Resize(self.h["image_size"], self.h["image_size"]), A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ToTensorV2(), ]) val_split = 0.2 train_filenames, val_filenames = self._split_file_names(self.data_dir+"train/", val_split) # Load the datasets self.train_dataset = CustomImageFolder(self.data_dir+"train/", transform=data_transforms_train_alb, is_valid_file=lambda x: x in train_filenames) self.val_dataset = CustomImageFolder(self.data_dir+"train/", transform=data_transforms_val_alb, is_valid_file=lambda x: x in val_filenames) self.test_dataset = CustomImageFolder(self.data_dir+"test/", transform=data_transforms_val_alb, is_valid_file=lambda x: self._is_image_file(x)) def train_dataloader(self): if self.h["balance"]: sampler = self._create_weighted_sampler(self.train_dataset) return torch.utils.data.DataLoader(self.train_dataset, batch_size=self.h["batch_size"], sampler=sampler, num_workers=0) else: return torch.utils.data.DataLoader(self.train_dataset, batch_size=self.h["batch_size"], shuffle=True, num_workers=0) def val_dataloader(self): return torch.utils.data.DataLoader(self.val_dataset, batch_size=self.h["batch_size"], num_workers=0) def test_dataloader(self): return torch.utils.data.DataLoader(self.test_dataset, batch_size=self.h["batch_size"], num_workers=0) def _extract_patient_ids(self, filename): patient_id = filename.split('_')[0].replace("person", "") return patient_id def _is_image_file(self, file_path): return file_path.lower().endswith((".jpeg", ".jpg", ".png")) def _split_file_names(self, input_folder, val_split_perc): # Pneumonia files contain patient id, so we group split them by patient to avoid data leakage pneumonia_patient_ids = set([self._extract_patient_ids(fn) for fn in os.listdir(os.path.join(input_folder, 'PNEUMONIA'))]) pneumonia_val_patient_ids = random.sample(list(pneumonia_patient_ids), int(val_split_perc * len(pneumonia_patient_ids))) pneumonia_val_filenames = [] pneumonia_train_filenames = [] for filename in os.listdir(os.path.join(input_folder, 'PNEUMONIA')): if self._is_image_file(filename): patient_id = self._extract_patient_ids(filename) if patient_id in pneumonia_val_patient_ids: pneumonia_val_filenames.append(os.path.join(input_folder, 'PNEUMONIA', filename)) else: pneumonia_train_filenames.append(os.path.join(input_folder, 'PNEUMONIA', filename)) # Normal (by file, no patient information in file names) normal_filenames = [os.path.join(input_folder, 'NORMAL', fn) for fn in os.listdir(os.path.join(input_folder, 'NORMAL'))] normal_filenames = [filename for filename in normal_filenames if self._is_image_file(filename)] normal_val_filenames = random.sample(normal_filenames, int(val_split_perc * len(normal_filenames))) normal_train_filenames = list(set(normal_filenames)-set(normal_val_filenames)) train_filenames = pneumonia_train_filenames + normal_train_filenames val_filenames = pneumonia_val_filenames + normal_val_filenames return train_filenames, val_filenames def _create_weighted_sampler(self, dataset): targets = dataset.targets class_counts = np.bincount(targets) class_weights = 1.0 / class_counts weights = [class_weights[label] for label in targets] sampler = WeightedRandomSampler(weights, len(weights)) return sampler #带seblock的Resnet34 class SEBlock(nn.Module): def __init__(self, channel, reduction=16): super(SEBlock, self).__init__() self.avg_pool = nn.AdaptiveAvgPool2d(1) self.fc = nn.Sequential( nn.Linear(channel, channel // reduction), nn.ReLU(inplace=True), nn.Linear(channel // reduction, channel), nn.Sigmoid() ) def forward(self, x): b, c, _, _ = x.size() y = self.avg_pool(x).view(b, c) y = self.fc(y).view(b, c, 1, 1) return x * y class SEBasicBlock(nn.Module): expansion = 1 def __init__(self, in_planes, planes, stride=1, downsample=None): super(SEBasicBlock, self).__init__() self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) self.bn1 = nn.BatchNorm2d(planes) self.relu = nn.ReLU(inplace=True) self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) self.bn2 = nn.BatchNorm2d(planes) self.se = SEBlock(planes) self.downsample = downsample def forward(self, x): identity = x out = self.relu(self.bn1(self.conv1(x))) out = self.bn2(self.conv2(out)) out = self.se(out) if self.downsample is not None: identity = self.downsample(x) out += identity out = self.relu(out) return out def make_layer(block, in_planes, planes, blocks, stride=1): downsample = None if stride != 1 or in_planes != planes * block.expansion: downsample = nn.Sequential( nn.Conv2d(in_planes, planes * block.expansion, kernel_size=1, stride=stride, bias=False), nn.BatchNorm2d(planes * block.expansion), ) layers = [block(in_planes, planes, stride, downsample)] for _ in range(1, blocks): layers.append(block(planes * block.expansion, planes)) return nn.Sequential(*layers) class CustomResNet(nn.Module): def __init__(self, block, layers, num_classes=2): super(CustomResNet, self).__init__() self.in_planes = 64 self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) self.bn1 = nn.BatchNorm2d(64) self.relu = nn.ReLU(inplace=True) self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) self.layer1 = make_layer(block, 64, 64, layers[0]) self.layer2 = make_layer(block, 64, 128, layers[1], stride=2) self.layer3 = make_layer(block, 128, 256, layers[2], stride=2) self.layer4 = make_layer(block, 256, 512, layers[3], stride=2) self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) self.fc = nn.Linear(512 * block.expansion, num_classes) def forward(self, x): x = self.relu(self.bn1(self.conv1(x))) x = self.maxpool(x) x = self.layer1(x) x = self.layer2(x) x = self.layer3(x) x = self.layer4(x) x = self.avgpool(x) x = torch.flatten(x, 1) x = self.fc(x) return x #替换损失函数 class FocalLoss(nn.Module): def __init__(self, alpha=1, gamma=2, reduction='mean'): super(FocalLoss, self).__init__() self.alpha = alpha self.gamma = gamma self.reduction = reduction def forward(self, inputs, targets): ce_loss = F.cross_entropy(inputs, targets, reduction='none') pt = torch.exp(-ce_loss) focal_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss return focal_loss.mean() if self.reduction == 'mean' else focal_loss.sum() class PneumoniaModel(pl.LightningModule): def __init__(self, h): super().__init__() self.h = h self.model = self._create_model() self.criterion = FocalLoss(alpha=1, gamma=2) self.test_outputs = [] def forward(self, x): return self.model(x) def training_step(self, batch, batch_idx): inputs, labels = batch outputs = self(inputs) loss = self.criterion(outputs, labels) self.log('train_loss', loss, on_epoch=True, on_step=True, prog_bar=True) return loss def validation_step(self, batch, batch_idx): inputs, labels = batch outputs = self(inputs) loss = self.criterion(outputs, labels) acc = (outputs.argmax(dim=1) == labels).float().mean() metrics = {"val_loss": loss, "val_acc": acc} self.log_dict(metrics, on_epoch=True, on_step=True, prog_bar=True) return metrics def on_test_epoch_start(self): self.test_outputs = [] def test_step(self, batch, batch_idx): inputs, labels = batch outputs = self(inputs) loss = self.criterion(outputs, labels) acc = (outputs.argmax(dim=1) == labels).float().mean() preds = torch.argmax(outputs, dim=1) self.test_outputs.append({"test_loss": loss, "test_acc": acc, "preds": preds, "labels": labels}) return {"test_loss": loss, "test_acc": acc, "preds": preds, "labels": labels} def on_test_epoch_end(self): test_loss_mean = torch.stack([x["test_loss"] for x in self.test_outputs]).mean() test_acc_mean = torch.stack([x["test_acc"] for x in self.test_outputs]).mean() self.test_predicted_labels = torch.cat([x["preds"] for x in self.test_outputs], dim=0).cpu().numpy() self.test_true_labels = torch.cat([x["labels"] for x in self.test_outputs], dim=0).cpu().numpy() #Todo: remove f1 calculation from here f1 = f1_score(self.test_true_labels, self.test_predicted_labels) self.test_f1 = f1 self.test_acc = test_acc_mean.cpu().numpy() #Todo - fix it def configure_optimizers(self): optimizer = torch.optim.Adam(self.parameters(), lr=self.h["lr"]) scheduler_dic = self._configure_scheduler(optimizer) if (scheduler_dic["scheduler"]): return { "optimizer": optimizer, "lr_scheduler": scheduler_dic } else: return optimizer def _configure_scheduler(self, optimizer): scheduler_name = self.h["scheduler"] lr = self.h["lr"] if (scheduler_name==""): return { "scheduler": None } if (scheduler_name=="CosineAnnealingLR10"): scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=h["num_epochs"], eta_min=lr*0.1) #*len(train_loader) if "step" return { "scheduler": scheduler, "interval": "epoch" } if (scheduler_name=="ReduceLROnPlateau5"): scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5, verbose=True) return { "scheduler": scheduler, "interval": "epoch", "monitor": "val_loss", "strict": True } print ("Error. Unknown scheduler name '{scheduler_name}'") return None def _create_model(self): if (self.h["model"]=="efficientnetv2"): return timm.create_model("tf_efficientnetv2_b0", pretrained=True, num_classes=2) if (self.h["model"]=="fc"): return nn.Sequential( nn.Flatten(), nn.Linear(3 * self.h["image_size"] * self.h["image_size"], self.h["fc1_size"]), nn.ReLU(), nn.Linear(self.h["fc1_size"], 2) ) if (self.h["model"]=="cnn"): return nn.Sequential( nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2, 2), nn.Conv2d(16, 32, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2, 2), nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2, 2), nn.Flatten(), nn.Dropout(0.25), nn.Linear(64 * (self.h["image_size"] // 8) * (self.h["image_size"] // 8), 512), nn.ReLU(), nn.Dropout(0.25), nn.Linear(512, 2) ) if (self.h["model"]=="resnet34_custom"): model = CustomResNet(SEBasicBlock, [3, 4, 6, 3], num_classes=2) return model class InfoPrinterCallback(pl.Callback): def __init__(self): device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if torch.cuda.is_available(): print(f"CPU cores: {os.cpu_count()}, Device: {device}, GPU: {torch.cuda.get_device_name(0)}") else: print(f"CPU cores: {os.cpu_count()}, Device: {device}") # Print hyperparameters for records print("Hyperparameters:") pprint.pprint(h, indent=4) def setup(self, trainer, pl_module, stage): self.start_time = time.time() def on_validation_epoch_end(self, trainer, pl_module): # Skip the sanity check if trainer.sanity_checking: return epoch = trainer.current_epoch total_epochs = trainer.max_epochs elapsed_time = time.time() - self.start_time avg_time_per_epoch = elapsed_time / (epoch + 1) avg_time_per_epoch_min, avg_time_per_epoch_sec = divmod(avg_time_per_epoch, 60) remaining_epochs = total_epochs - epoch - 1 remaining_time = remaining_epochs * avg_time_per_epoch remaining_time_min, remaining_time_sec = divmod(remaining_time, 60) print(f"Epoch {epoch + 1}/{total_epochs}: ", end="") if "val_loss" in trainer.callback_metrics: validation_loss = trainer.callback_metrics["val_loss"].cpu().numpy() #self.validation_losses.append(validation_loss) print(f"Validation Loss = {validation_loss:.4f}", end="") else: print(f"Validation Loss not available", end="") if "train_loss_epoch" in trainer.logged_metrics: train_loss = trainer.logged_metrics["train_loss_epoch"].cpu().numpy() print(f", Train Loss = {train_loss:.4f}", end="") else: print(f", Train Loss not available", end="") print(f", Epoch Time: {avg_time_per_epoch_min:.0f}m {avg_time_per_epoch_sec:02.0f}s, Remaining Time: {remaining_time_min:.0f}m {remaining_time_sec:02.0f}s") def plot_losses(self): plt.style.use("seaborn-v0_8-whitegrid") plt.rcParams.update({ "font.family": "serif", "font.size": 10, "axes.titlesize": 11, "axes.labelsize": 10, "xtick.labelsize": 9, "ytick.labelsize": 9, "legend.fontsize": 9, "figure.figsize": (4, 3), "figure.dpi": 150, "lines.linewidth": 2, "lines.markersize": 5 }) plt.figure() plt.plot(self.training_losses, label="Training Loss", marker='o') plt.plot(self.validation_losses, label="Validation Loss", marker='s') plt.xlabel("Epoch") plt.ylabel("Loss") plt.legend() plt.tight_layout() plt.show() class PlotTestConfusionMatrixCallback(pl.Callback): def on_test_end(self, trainer, pl_module): cm = confusion_matrix(pl_module.test_true_labels, pl_module.test_predicted_labels) class_names = ["Normal", "Pneumonia"] plt.style.use("seaborn-v0_8-white") plt.rcParams.update({ "font.family": "serif", "font.size": 10, "axes.titlesize": 11, "axes.labelsize": 10, "xtick.labelsize": 9, "ytick.labelsize": 9, "legend.fontsize": 9, "figure.figsize": (4, 3), "figure.dpi": 150 }) fig, ax = plt.subplots() sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names, cbar=False, linewidths=0.5, linecolor='gray') ax.set_xlabel("Predicted Label") ax.set_ylabel("True Label") ax.set_title("Confusion Matrix") plt.tight_layout() plt.show() class PlotTrainingLogsCallback(pl.Callback): def __init__(self): self.validation_losses = [] self.training_losses = [] def on_train_epoch_end(self, trainer, pl_module): if "train_loss_epoch" in trainer.logged_metrics: train_loss = trainer.logged_metrics["train_loss_epoch"].cpu().numpy() self.training_losses.append(train_loss) def on_validation_epoch_end(self, trainer, pl_module): if trainer.sanity_checking: return if "val_loss" in trainer.callback_metrics: validation_loss = trainer.callback_metrics["val_loss"].cpu().numpy() self.validation_losses.append(validation_loss) def on_fit_end(self, trainer, pl_module): plt.style.use("seaborn-v0_8-whitegrid") plt.rcParams.update({ "font.family": "serif", "font.size": 10, "axes.titlesize": 11, "axes.labelsize": 10, "xtick.labelsize": 9, "ytick.labelsize": 9, "legend.fontsize": 9, "figure.figsize": (4, 3), "figure.dpi": 150, "lines.linewidth": 2, "lines.markersize": 5 }) plt.figure() plt.plot(self.training_losses, label="Training Loss", marker='o') plt.plot(self.validation_losses, label="Validation Loss", marker='s') plt.xlabel("Epoch") plt.ylabel("Loss") plt.legend() plt.tight_layout() plt.show() def check_solution(h, verbose): pneumonia_data = PneumoniaDataModule(h, "D:/DATASET/archive/chest_xray/chest_xray/") pneumonia_model = PneumoniaModel(h) # Callbacks info_printer = InfoPrinterCallback() early_stopping = pl.callbacks.EarlyStopping( monitor="val_loss", patience=h["early_stopping_patience"], verbose=True, ) checkpoint_callback = pl.callbacks.ModelCheckpoint( dirpath="model_checkpoints", monitor="val_loss", verbose=True, ) callbacks = [info_printer, early_stopping, checkpoint_callback] if (verbose): callbacks.append(PlotTestConfusionMatrixCallback()) callbacks.append(PlotTrainingLogsCallback()) trainer = pl.Trainer( max_epochs=h["num_epochs"], accelerator="auto", callbacks=callbacks, log_every_n_steps=1, fast_dev_run=False ) trainer.fit(pneumonia_model, datamodule=pneumonia_data) if (h["use_best_checkpoint"]): #Debug lines trainer.test(pneumonia_model, datamodule=pneumonia_data) print(f"Last: F1= {pneumonia_model.test_f1:.4f}, Acc= {pneumonia_model.test_acc:.4f}") best_model_path = checkpoint_callback.best_model_path best_model = PneumoniaModel.load_from_checkpoint(best_model_path, h=h) pneumonia_model = best_model trainer.test(pneumonia_model, datamodule=pneumonia_data) print(f"Best: F1= {pneumonia_model.test_f1:.4f}, Acc= {pneumonia_model.test_acc:.4f}") return pneumonia_model.test_f1, pneumonia_model.test_acc logging.getLogger("pytorch_lightning").setLevel(logging.ERROR) f1_array = np.array([]) accuracy_array = np.array([]) start_time = time.time() repeats = 1 for i in range(repeats): print("===============================================") print(f"Running solution {i+1}/{repeats}") f1, accuracy = check_solution(h, verbose=(i==0)) print(f"F1 = {f1:.2f}, accuracy = {accuracy:.2f} ") f1_array = np.append(f1_array, f1) accuracy_array = np.append(accuracy_array, accuracy) # Calculate elapsed time and remaining time repeat_time = (time.time() - start_time) / repeats repeat_time_min, repeat_time_sec = divmod(repeat_time, 60) # Printing final results print("Results") print(f"F1: {np.mean(f1_array):.1%} (+-{np.std(f1_array):.1%})") print(f"Accuracy: {np.mean(accuracy_array):.1%} (+-{np.std(accuracy_array):.1%})") print(f"Time of one solution: {repeat_time_min:.0f}m {repeat_time_sec:.0f}s") print(f" | {np.mean(f1_array):.1%} (+-{np.std(f1_array):.1%}) | {np.mean(accuracy_array):.1%} (+-{np.std(accuracy_array):.1%}) | {repeat_time_min:.0f}m {repeat_time_sec:.0f}s") # Print hyperparameters for reminding what the final data is for print("Hyperparameters:") pprint.pprint(h, indent=4)