#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Training-process visualization script.

Visualizes the training process of the One-Prompt model, including:
1. Training / validation loss curves
2. IoU and Dice metric curves
3. Learning-rate curve
4. Segmentation result visualization

Usage:
    python scripts/visualize_training.py --log_dir logs/polyp_val_test_2025_12_16_23_52_30
"""

import os
import re
import argparse
from pathlib import Path
from typing import List, Tuple, Dict

import matplotlib
# Select the non-GUI backend *before* pyplot is imported so the choice is
# honoured on every matplotlib version and on headless machines.
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numpy as np

# Font configuration so CJK glyphs and the minus sign render correctly.
plt.rcParams['font.sans-serif'] = ['SimHei', 'DejaVu Sans']
plt.rcParams['axes.unicode_minus'] = False

# A decimal number with optional sign and optional exponent (e.g. "0.45",
# "3.5e-05"). Small losses are frequently logged in scientific notation.
_FLOAT = r'[-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?'

# Example: "Train loss: 0.455222487449646|| @ epoch 0."
_TRAIN_RE = re.compile(r'Train loss: (' + _FLOAT + r')\|\| @ epoch (\d+)')

# Example: "Total score: 0.367, IOU: 0.012, DICE: 0.022 || @ epoch 2."
_VAL_RE = re.compile(
    r'Total score: (' + _FLOAT + r'), IOU: (' + _FLOAT + r'), '
    r'DICE: (' + _FLOAT + r') \|\| @ epoch (\d+)'
)


def parse_log_file(log_path: str) -> Dict[str, List[float]]:
    """
    Parse a training log file.

    Each epoch is recorded at most once per series (first occurrence wins),
    so a log whose lines are written twice does not produce duplicated
    points.

    Args:
        log_path: Path of the log file.

    Returns:
        Dictionary with the keys ``epoch``, ``train_loss``, ``val_loss``,
        ``iou``, ``dice`` and ``val_epoch``. ``epoch`` / ``train_loss`` are
        aligned with each other, and ``val_epoch`` / ``val_loss`` / ``iou`` /
        ``dice`` are aligned with each other.
    """
    metrics = {
        'epoch': [],
        'train_loss': [],
        'val_loss': [],
        'iou': [],
        'dice': [],
        'val_epoch': [],
    }
    seen_train_epochs = set()
    seen_val_epochs = set()

    # errors='replace': a stray undecodable byte in an unrelated line must not
    # abort parsing; the patterns we match are pure ASCII.
    with open(log_path, 'r', encoding='utf-8', errors='replace') as f:
        for line in f:
            train_match = _TRAIN_RE.search(line)
            if train_match:
                epoch = int(train_match.group(2))
                if epoch not in seen_train_epochs:
                    seen_train_epochs.add(epoch)
                    metrics['train_loss'].append(float(train_match.group(1)))
                    metrics['epoch'].append(epoch)

            val_match = _VAL_RE.search(line)
            if val_match:
                val_epoch = int(val_match.group(4))
                if val_epoch not in seen_val_epochs:
                    seen_val_epochs.add(val_epoch)
                    metrics['val_loss'].append(float(val_match.group(1)))
                    metrics['iou'].append(float(val_match.group(2)))
                    metrics['dice'].append(float(val_match.group(3)))
                    metrics['val_epoch'].append(val_epoch)

    return metrics


def _validation_x(metrics: Dict[str, List[float]]) -> List[float]:
    """Return the x positions (epochs) for the validation-loss series."""
    n_val = len(metrics['val_loss'])
    val_epochs = metrics.get('val_epoch') or []
    if len(val_epochs) == n_val:
        # Real epochs recorded by parse_log_file.
        return list(val_epochs)
    epochs = metrics['epoch']
    if epochs:
        # Legacy dictionaries without 'val_epoch': spread the validation
        # points evenly over the training epoch range.
        return list(np.linspace(0, max(epochs), n_val))
    # No training epochs at all: fall back to the validation step index.
    return list(range(n_val))


def plot_loss_curves(metrics: Dict[str, List[float]], save_path: str) -> None:
    """
    Plot the training and validation loss curves.

    Args:
        metrics: Training metrics dictionary.
        save_path: Path the image is saved to.
    """
    fig, ax = plt.subplots(figsize=(10, 6))

    epochs = metrics['epoch']
    train_loss = metrics['train_loss']

    # Training loss.
    ax.plot(epochs, train_loss, 'b-', label='Training Loss', linewidth=2)

    # Validation loss, if any. Validation only runs every few epochs, so it
    # needs its own x positions.
    if metrics['val_loss']:
        ax.plot(_validation_x(metrics), metrics['val_loss'], 'r--',
                label='Validation Loss', linewidth=2)

    ax.set_xlabel('Epoch', fontsize=12)
    ax.set_ylabel('Loss', fontsize=12)
    ax.set_title('Training and Validation Loss Curves', fontsize=14)
    ax.legend(loc='upper right', fontsize=10)
    ax.grid(True, alpha=0.3)

    # Annotate the best (lowest) training loss.
    if train_loss:
        min_idx = int(np.argmin(train_loss))
        ax.annotate(f'Best: {train_loss[min_idx]:.4f}',
                    xy=(epochs[min_idx], train_loss[min_idx]),
                    xytext=(epochs[min_idx] + 2, train_loss[min_idx] + 0.1),
                    arrowprops=dict(arrowstyle='->', color='blue'),
                    fontsize=10, color='blue')

    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close()
    print(f"损失曲线已保存至: {save_path}")


def _plot_metric_with_best(ax, values: List[float], fmt: str, label: str,
                           ylabel: str, title: str, color: str) -> None:
    """Plot one validation metric against its step index and mark its maximum."""
    ax.plot(range(len(values)), values, fmt, label=label,
            linewidth=2, markersize=6)
    ax.set_xlabel('Validation Step', fontsize=12)
    ax.set_ylabel(ylabel, fontsize=12)
    ax.set_title(title, fontsize=14)
    ax.grid(True, alpha=0.3)

    best_idx = int(np.argmax(values))
    best = values[best_idx]
    ax.annotate(f'Best: {best:.4f}',
                xy=(best_idx, best),
                xytext=(best_idx + 0.5, best + 0.01),
                arrowprops=dict(arrowstyle='->', color=color),
                fontsize=10, color=color)


def plot_metric_curves(metrics: Dict[str, List[float]], save_path: str) -> None:
    """
    Plot the IoU and Dice metric curves.

    Nothing is drawn (a warning is printed instead) when either metric has
    no data.

    Args:
        metrics: Training metrics dictionary.
        save_path: Path the image is saved to.
    """
    if not metrics['iou'] or not metrics['dice']:
        print("警告: 没有IoU/Dice指标数据")
        return

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

    _plot_metric_with_best(ax1, metrics['iou'], 'g-o', 'IoU', 'IoU',
                           'Intersection over Union (IoU)', 'green')
    _plot_metric_with_best(ax2, metrics['dice'], 'm-s', 'Dice', 'Dice Score',
                           'Dice Coefficient', 'purple')

    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close()
    print(f"指标曲线已保存至: {save_path}")


def plot_combined_dashboard(metrics: Dict[str, List[float]], save_path: str) -> None:
    """
    Plot the combined training dashboard (2 x 3 grid of panels).

    Args:
        metrics: Training metrics dictionary.
        save_path: Path the image is saved to.
    """
    fig = plt.figure(figsize=(16, 10))

    # Sub-plot layout.
    gs = fig.add_gridspec(2, 3, hspace=0.3, wspace=0.3)

    train_loss = metrics['train_loss']
    val_loss = metrics['val_loss']

    # 1. Training loss curve.
    ax1 = fig.add_subplot(gs[0, 0])
    if train_loss:
        ax1.plot(metrics['epoch'], train_loss, 'b-', linewidth=2)
        ax1.fill_between(metrics['epoch'], train_loss, alpha=0.3)
    ax1.set_title('Training Loss', fontsize=12, fontweight='bold')
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Loss')
    ax1.grid(True, alpha=0.3)

    # 2. Validation loss curve.
    ax2 = fig.add_subplot(gs[0, 1])
    if val_loss:
        val_steps = range(len(val_loss))
        ax2.plot(val_steps, val_loss, 'r-', linewidth=2)
        ax2.fill_between(val_steps, val_loss, alpha=0.3, color='red')
    ax2.set_title('Validation Loss', fontsize=12, fontweight='bold')
    ax2.set_xlabel('Validation Step')
    ax2.set_ylabel('Loss')
    ax2.grid(True, alpha=0.3)

    # 3. IoU curve.
    ax3 = fig.add_subplot(gs[0, 2])
    if metrics['iou']:
        ax3.plot(range(len(metrics['iou'])), metrics['iou'], 'g-o',
                 linewidth=2, markersize=4)
    ax3.set_title('IoU Score', fontsize=12, fontweight='bold')
    ax3.set_xlabel('Validation Step')
    ax3.set_ylabel('IoU')
    ax3.grid(True, alpha=0.3)

    # 4. Dice curve.
    ax4 = fig.add_subplot(gs[1, 0])
    if metrics['dice']:
        ax4.plot(range(len(metrics['dice'])), metrics['dice'], 'm-s',
                 linewidth=2, markersize=4)
    ax4.set_title('Dice Score', fontsize=12, fontweight='bold')
    ax4.set_xlabel('Validation Step')
    ax4.set_ylabel('Dice')
    ax4.grid(True, alpha=0.3)

    # 5. Average-loss comparison bar chart (needs both series).
    ax5 = fig.add_subplot(gs[1, 1])
    if train_loss and val_loss:
        x = np.arange(2)
        values = [np.mean(train_loss), np.mean(val_loss)]
        bars = ax5.bar(x, values, color=['blue', 'red'], alpha=0.7)
        ax5.set_xticks(x)
        ax5.set_xticklabels(['Avg Train Loss', 'Avg Val Loss'])
        ax5.set_title('Average Loss Comparison', fontsize=12, fontweight='bold')

        # Numeric label on top of each bar.
        for bar, val in zip(bars, values):
            ax5.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01,
                     f'{val:.4f}', ha='center', va='bottom', fontsize=10)

    # 6. Training statistics text panel.
    ax6 = fig.add_subplot(gs[1, 2])
    ax6.axis('off')

    stats_text = "Training Statistics\n" + "="*30 + "\n\n"
    if train_loss:
        stats_text += f"Total Epochs: {len(metrics['epoch'])}\n"
        stats_text += f"Final Train Loss: {train_loss[-1]:.4f}\n"
        stats_text += f"Best Train Loss: {min(train_loss):.4f}\n"
    if val_loss:
        stats_text += f"Final Val Loss: {val_loss[-1]:.4f}\n"
        stats_text += f"Best Val Loss: {min(val_loss):.4f}\n"
    if metrics['iou']:
        stats_text += f"Best IoU: {max(metrics['iou']):.4f}\n"
    if metrics['dice']:
        stats_text += f"Best Dice: {max(metrics['dice']):.4f}\n"

    ax6.text(0.1, 0.5, stats_text, transform=ax6.transAxes,
             fontsize=11, verticalalignment='center',
             fontfamily='monospace',
             bbox=dict(boxstyle='round', facecolor='lightgray', alpha=0.5))

    # Overall figure title.
    fig.suptitle('One-Prompt Medical Image Segmentation - Training Dashboard',
                 fontsize=16, fontweight='bold', y=0.98)

    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close()
    print(f"训练仪表板已保存至: {save_path}")


def main():
    """Entry point: parse the log under ``--log_dir`` and write all figures."""
    parser = argparse.ArgumentParser(description='可视化训练过程')
    parser.add_argument('--log_dir', type=str, required=True,
                        help='日志目录路径')
    parser.add_argument('--output_dir', type=str, default=None,
                        help='输出目录')
    args = parser.parse_args()

    log_dir = Path(args.log_dir)
    output_dir = Path(args.output_dir) if args.output_dir else log_dir / 'visualizations'
    output_dir.mkdir(parents=True, exist_ok=True)

    # Locate the log file. glob() yields entries in arbitrary filesystem
    # order, so sort to make the choice deterministic when several exist.
    log_files = sorted(log_dir.glob('Log/*.log'))
    if not log_files:
        print(f"错误: 在 {log_dir}/Log/ 目录下未找到日志文件")
        return

    log_path = log_files[0]
    print(f"正在解析日志文件: {log_path}")

    # Parse the log.
    metrics = parse_log_file(str(log_path))
    print(f"解析完成: {len(metrics['epoch'])} 个epoch, "
          f"{len(metrics['val_loss'])} 次验证")

    # Generate the visualizations.
    plot_loss_curves(metrics, str(output_dir / 'loss_curves.png'))
    plot_metric_curves(metrics, str(output_dir / 'metric_curves.png'))
    plot_combined_dashboard(metrics, str(output_dir / 'training_dashboard.png'))

    print(f"\n所有可视化已保存至: {output_dir}")


if __name__ == '__main__':
    main()