You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

306 lines
10 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
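Training-process visualization script.
This script visualizes the training process of the One-Prompt model, including:
1. Training/validation loss curves
2. IoU and Dice metric curves
3. Learning-rate schedule curve
4. Segmentation result visualization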
Usage:
python scripts/visualize_training.py --log_dir logs/polyp_val_test_2025_12_16_23_52_30
"""
import os
import re
import argparse
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
from typing import List, Tuple, Dict
import matplotlib
matplotlib.use('Agg') # Non-GUI backend: figures are rendered straight to files
# Font fallback list for figure text: SimHei first (Chinese glyphs), then DejaVu Sans.
plt.rcParams['font.sans-serif'] = ['SimHei', 'DejaVu Sans']
# Draw minus signs with the ASCII hyphen rather than U+2212 -- presumably
# because the CJK font selected above may not provide that glyph.
plt.rcParams['axes.unicode_minus'] = False
# A decimal number as Python prints floats: optional sign, digits with an
# optional fraction, optional exponent (e.g. "0.45", "12", "1.5e-05").
_NUMBER = r'[-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?'

# Example line: "Train loss: 0.455222487449646|| @ epoch 0."
_TRAIN_LINE_RE = re.compile(
    r'Train loss: (' + _NUMBER + r')\|\| @ epoch (\d+)'
)

# Example line: "Total score: 0.367, IOU: 0.012, DICE: 0.022 || @ epoch 2."
_VAL_LINE_RE = re.compile(
    r'Total score: (' + _NUMBER + r'), IOU: (' + _NUMBER + r'), '
    r'DICE: (' + _NUMBER + r') \|\| @ epoch (\d+)'
)


def parse_log_file(log_path: str) -> Dict[str, List[float]]:
    """
    Parse a training log file into per-epoch metric lists.

    Two kinds of lines are recognised; every other line is ignored:

    * ``Train loss: <loss>|| @ epoch <n>`` fills ``'epoch'`` and
      ``'train_loss'``.
    * ``Total score: <s>, IOU: <i>, DICE: <d> || @ epoch <n>`` fills
      ``'val_epoch'``, ``'val_loss'`` (the "Total score" value), ``'iou'``
      and ``'dice'``.

    Numbers may be written in plain or scientific notation. When the same
    epoch is reported more than once for a given line kind, only the first
    occurrence is kept, so repeated log lines do not produce duplicate points.

    Args:
        log_path: Path of the log file to read.

    Returns:
        Dictionary with the keys ``'epoch'``, ``'train_loss'``,
        ``'val_epoch'``, ``'val_loss'``, ``'iou'`` and ``'dice'``. The two
        training lists are index-aligned with each other, and so are the
        four validation lists.
    """
    metrics = {
        'epoch': [],
        'train_loss': [],
        'val_epoch': [],
        'val_loss': [],
        'iou': [],
        'dice': [],
    }
    seen_train_epochs = set()
    seen_val_epochs = set()

    # errors='replace': a stray undecodable byte elsewhere in the log must not
    # abort parsing; the metric lines themselves are plain ASCII.
    with open(log_path, 'r', encoding='utf-8', errors='replace') as f:
        for line in f:
            train_match = _TRAIN_LINE_RE.search(line)
            if train_match:
                epoch = int(train_match.group(2))
                if epoch not in seen_train_epochs:
                    seen_train_epochs.add(epoch)
                    metrics['epoch'].append(epoch)
                    metrics['train_loss'].append(float(train_match.group(1)))

            val_match = _VAL_LINE_RE.search(line)
            if val_match:
                val_epoch = int(val_match.group(4))
                if val_epoch not in seen_val_epochs:
                    seen_val_epochs.add(val_epoch)
                    metrics['val_epoch'].append(val_epoch)
                    metrics['val_loss'].append(float(val_match.group(1)))
                    metrics['iou'].append(float(val_match.group(2)))
                    metrics['dice'].append(float(val_match.group(3)))

    return metrics
def plot_loss_curves(metrics: Dict[str, List[float]], save_path: str) -> None:
    """
    Plot the training loss and, when present, the validation loss.

    Args:
        metrics: Metrics dictionary. Reads ``'epoch'``, ``'train_loss'`` and
            ``'val_loss'``. If an optional ``'val_epoch'`` list with the same
            length as ``'val_loss'`` is present, it supplies the x positions
            of the validation points; otherwise they are spread evenly over
            the training epoch range.
        save_path: Path of the image file to write.
    """
    fig, ax = plt.subplots(figsize=(10, 6))
    epochs = metrics['epoch']
    train_loss = metrics['train_loss']
    val_loss = metrics['val_loss']

    ax.plot(epochs, train_loss, 'b-', label='Training Loss', linewidth=2)

    if val_loss:
        val_epochs = metrics.get('val_epoch')
        if not val_epochs or len(val_epochs) != len(val_loss):
            # No recorded validation epochs: validation runs only every few
            # epochs, so approximate by spacing the points evenly. Without any
            # training epochs, fall back to the validation step index rather
            # than calling max() on an empty list.
            last_epoch = max(epochs) if epochs else len(val_loss) - 1
            val_epochs = np.linspace(0, last_epoch, len(val_loss))
        ax.plot(val_epochs, val_loss, 'r--', label='Validation Loss', linewidth=2)

    ax.set_xlabel('Epoch', fontsize=12)
    ax.set_ylabel('Loss', fontsize=12)
    ax.set_title('Training and Validation Loss Curves', fontsize=14)
    ax.legend(loc='upper right', fontsize=10)
    ax.grid(True, alpha=0.3)

    # Mark the lowest training loss.
    if train_loss:
        min_idx = int(np.argmin(train_loss))
        # The label offset is given in points, not data units, so it stays
        # next to the marker whatever the scale of the loss axis is.
        ax.annotate(f'Best: {train_loss[min_idx]:.4f}',
                    xy=(epochs[min_idx], train_loss[min_idx]),
                    xytext=(20, 20), textcoords='offset points',
                    arrowprops=dict(arrowstyle='->', color='blue'),
                    fontsize=10, color='blue')

    fig.tight_layout()
    fig.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close(fig)
    print(f"损失曲线已保存至: {save_path}")
def plot_metric_curves(metrics: Dict[str, List[float]], save_path: str) -> None:
    """
    Draw the IoU and Dice curves side by side and save them as one image.

    Prints a warning and returns without drawing when either metric list
    is empty.

    Args:
        metrics: Metrics dictionary; reads the ``'iou'`` and ``'dice'`` lists.
        save_path: Path of the image file to write.
    """
    iou_values = metrics['iou']
    dice_values = metrics['dice']
    if not (iou_values and dice_values):
        print("警告: 没有IoU/Dice指标数据")
        return

    fig, (iou_ax, dice_ax) = plt.subplots(1, 2, figsize=(14, 5))
    # Both panels share one x axis: the index of the validation run,
    # sized from the IoU list.
    steps = range(len(iou_values))

    # (axes, values, line format, legend label, y label, title, note colour)
    panels = (
        (iou_ax, iou_values, 'g-o', 'IoU', 'IoU',
         'Intersection over Union (IoU)', 'green'),
        (dice_ax, dice_values, 'm-s', 'Dice', 'Dice Score',
         'Dice Coefficient', 'purple'),
    )
    for axis, values, line_fmt, legend_label, y_label, title, note_color in panels:
        axis.plot(steps, values, line_fmt, label=legend_label,
                  linewidth=2, markersize=6)
        axis.set_xlabel('Validation Step', fontsize=12)
        axis.set_ylabel(y_label, fontsize=12)
        axis.set_title(title, fontsize=14)
        axis.grid(True, alpha=0.3)

        # Point an arrow at the highest value of this metric.
        best_step = np.argmax(values)
        best_value = values[best_step]
        axis.annotate(f'Best: {best_value:.4f}',
                      xy=(best_step, best_value),
                      xytext=(best_step + 0.5, best_value + 0.01),
                      arrowprops=dict(arrowstyle='->', color=note_color),
                      fontsize=10, color=note_color)

    fig.tight_layout()
    fig.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close(fig)
    print(f"指标曲线已保存至: {save_path}")
def plot_combined_dashboard(metrics: Dict[str, List[float]], save_path: str) -> None:
    """
    Render a 2x3 dashboard summarising the training run and save it.

    Panels, row by row: training loss, validation loss, IoU, Dice, a bar
    chart comparing the mean training and validation loss, and a text box
    with summary statistics. A panel whose data list is empty is left
    without a curve.

    Args:
        metrics: Metrics dictionary; reads ``'epoch'``, ``'train_loss'``,
            ``'val_loss'``, ``'iou'`` and ``'dice'``.
        save_path: Path of the image file to write.
    """
    fig = plt.figure(figsize=(16, 10))
    grid = fig.add_gridspec(2, 3, hspace=0.3, wspace=0.3)

    # Panel 1: training loss over epochs.
    train_ax = fig.add_subplot(grid[0, 0])
    train_loss = metrics['train_loss']
    if train_loss:
        train_ax.plot(metrics['epoch'], train_loss, 'b-', linewidth=2)
        train_ax.fill_between(metrics['epoch'], train_loss, alpha=0.3)
    train_ax.set_title('Training Loss', fontsize=12, fontweight='bold')
    train_ax.set_xlabel('Epoch')
    train_ax.set_ylabel('Loss')
    train_ax.grid(True, alpha=0.3)

    # Panel 2: validation loss over validation runs.
    val_ax = fig.add_subplot(grid[0, 1])
    val_loss = metrics['val_loss']
    if val_loss:
        val_steps = range(len(val_loss))
        val_ax.plot(val_steps, val_loss, 'r-', linewidth=2)
        val_ax.fill_between(val_steps, val_loss, alpha=0.3, color='red')
    val_ax.set_title('Validation Loss', fontsize=12, fontweight='bold')
    val_ax.set_xlabel('Validation Step')
    val_ax.set_ylabel('Loss')
    val_ax.grid(True, alpha=0.3)

    # Panel 3: IoU over validation runs.
    iou_ax = fig.add_subplot(grid[0, 2])
    iou_scores = metrics['iou']
    if iou_scores:
        iou_ax.plot(range(len(iou_scores)), iou_scores, 'g-o',
                    linewidth=2, markersize=4)
    iou_ax.set_title('IoU Score', fontsize=12, fontweight='bold')
    iou_ax.set_xlabel('Validation Step')
    iou_ax.set_ylabel('IoU')
    iou_ax.grid(True, alpha=0.3)

    # Panel 4: Dice over validation runs.
    dice_ax = fig.add_subplot(grid[1, 0])
    dice_scores = metrics['dice']
    if dice_scores:
        dice_ax.plot(range(len(dice_scores)), dice_scores, 'm-s',
                     linewidth=2, markersize=4)
    dice_ax.set_title('Dice Score', fontsize=12, fontweight='bold')
    dice_ax.set_xlabel('Validation Step')
    dice_ax.set_ylabel('Dice')
    dice_ax.grid(True, alpha=0.3)

    # Panel 5: mean training loss versus mean validation loss.
    bar_ax = fig.add_subplot(grid[1, 1])
    if train_loss and val_loss:
        positions = np.arange(2)
        averages = [np.mean(train_loss), np.mean(val_loss)]
        rects = bar_ax.bar(positions, averages, color=['blue', 'red'], alpha=0.7)
        bar_ax.set_xticks(positions)
        bar_ax.set_xticklabels(['Avg Train Loss', 'Avg Val Loss'])
        bar_ax.set_title('Average Loss Comparison', fontsize=12, fontweight='bold')
        # Write each mean just above its bar.
        for rect, average in zip(rects, averages):
            label_x = rect.get_x() + rect.get_width()/2
            label_y = rect.get_height() + 0.01
            bar_ax.text(label_x, label_y, f'{average:.4f}',
                        ha='center', va='bottom', fontsize=10)

    # Panel 6: text-only summary, so the axes frame is hidden.
    stats_ax = fig.add_subplot(grid[1, 2])
    stats_ax.axis('off')

    stat_parts = ["Training Statistics\n", "="*30, "\n\n"]
    if train_loss:
        stat_parts.append(f"Total Epochs: {len(metrics['epoch'])}\n")
        stat_parts.append(f"Final Train Loss: {train_loss[-1]:.4f}\n")
        stat_parts.append(f"Best Train Loss: {min(train_loss):.4f}\n")
    if val_loss:
        stat_parts.append(f"Final Val Loss: {val_loss[-1]:.4f}\n")
        stat_parts.append(f"Best Val Loss: {min(val_loss):.4f}\n")
    if iou_scores:
        stat_parts.append(f"Best IoU: {max(iou_scores):.4f}\n")
    if dice_scores:
        stat_parts.append(f"Best Dice: {max(dice_scores):.4f}\n")
    stats_text = "".join(stat_parts)

    stats_ax.text(0.1, 0.5, stats_text, transform=stats_ax.transAxes, fontsize=11,
                  verticalalignment='center', fontfamily='monospace',
                  bbox=dict(boxstyle='round', facecolor='lightgray', alpha=0.5))

    # Figure-level heading above all panels.
    fig.suptitle('One-Prompt Medical Image Segmentation - Training Dashboard',
                 fontsize=16, fontweight='bold', y=0.98)

    fig.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close(fig)
    print(f"训练仪表板已保存至: {save_path}")
def main():
    """
    Command-line entry point: parse a run's log and write all figures.

    Reads ``--log_dir`` (required) and ``--output_dir`` (optional; defaults to
    ``<log_dir>/visualizations``). Log files are looked up as
    ``<log_dir>/Log/*.log``. When several exist, the most recently modified
    one is used. If none is found, an error is printed and nothing is written.
    """
    parser = argparse.ArgumentParser(description='可视化训练过程')
    parser.add_argument('--log_dir', type=str, required=True, help='日志目录路径')
    parser.add_argument('--output_dir', type=str, default=None, help='输出目录')
    args = parser.parse_args()

    log_dir = Path(args.log_dir)
    output_dir = Path(args.output_dir) if args.output_dir else log_dir / 'visualizations'

    # glob() yields files in filesystem order; sort so the choice below is
    # reproducible even when modification times tie.
    log_files = sorted(log_dir.glob('Log/*.log'))
    if not log_files:
        print(f"错误: 在 {log_dir}/Log/ 目录下未找到日志文件")
        return

    log_path = max(log_files, key=lambda p: p.stat().st_mtime)
    if len(log_files) > 1:
        print(f"提示: 找到 {len(log_files)} 个日志文件, 使用最新修改的文件")
    print(f"正在解析日志文件: {log_path}")

    # Create the output directory only once there is something to write, so a
    # failed lookup does not leave an empty directory behind.
    output_dir.mkdir(parents=True, exist_ok=True)

    metrics = parse_log_file(str(log_path))
    print(f"解析完成: {len(metrics['epoch'])} 个epoch, "
          f"{len(metrics['val_loss'])} 次验证")

    plot_loss_curves(metrics, str(output_dir / 'loss_curves.png'))
    plot_metric_curves(metrics, str(output_dir / 'metric_curves.png'))
    plot_combined_dashboard(metrics, str(output_dir / 'training_dashboard.png'))
    print(f"\n所有可视化已保存至: {output_dir}")


if __name__ == '__main__':
    main()