|
|
#!/usr/bin/env python3
|
|
|
# -*- coding: utf-8 -*-
|
|
|
"""
|
|
|
训练过程可视化脚本
|
|
|
|
|
|
该脚本用于可视化One-Prompt模型的训练过程,包括:

1. 训练/验证损失曲线

2. IoU和Dice指标曲线

3. 综合训练仪表板(训练/验证损失、IoU/Dice曲线、平均损失对比及训练统计信息)

(学习率变化曲线和分割结果可视化尚未实现: 日志解析器目前既不提取学习率, 也不读取预测掩码。)
|
|
|
|
|
|
Usage:
|
|
|
python scripts/visualize_training.py --log_dir logs/polyp_val_test_2025_12_16_23_52_30
|
|
|
"""
|
|
|
|
|
|
import os
|
|
|
import re
|
|
|
import argparse
|
|
|
import matplotlib.pyplot as plt
|
|
|
import numpy as np
|
|
|
from pathlib import Path
|
|
|
from typing import List, Tuple, Dict
|
|
|
import matplotlib
|
|
|
# Switch to the non-interactive Agg backend: figures are only rendered to
# files (savefig), so no display is needed.
matplotlib.use('Agg')

# Font setup: try SimHei (a CJK-capable font) first, then DejaVu Sans.
# `axes.unicode_minus = False` makes negative tick labels use an ASCII hyphen
# instead of U+2212, which CJK fonts often lack.
plt.rcParams.update({
    'font.sans-serif': ['SimHei', 'DejaVu Sans'],
    'axes.unicode_minus': False,
})
|
|
|
|
|
|
|
|
|
def parse_log_file(log_path: str) -> Dict[str, List[float]]:
    """
    Parse a training log file into per-epoch metric lists.

    Two line formats are recognised; every other line is ignored:

    * ``Train loss: <x>|| @ epoch <n>.`` appends ``<x>`` to ``train_loss``
      and ``<n>`` to ``epoch``. A line is skipped when ``<n>`` is smaller
      than the number of training entries already collected, which drops
      repeated lines for an epoch that was already recorded.
    * ``Total score: <x>, IOU: <y>, DICE: <z> || @ epoch <n>.`` appends to
      ``val_loss``, ``iou``, ``dice`` and ``val_epoch`` respectively.

    Numbers may be written in plain or scientific notation (``1.5e-05``) or
    as ``nan``/``inf``; lines using those forms used to be dropped silently.
    Undecodable bytes are replaced rather than aborting the whole parse.

    Args:
        log_path: path to the log file.

    Returns:
        Dict with the lists ``epoch``, ``train_loss``, ``val_loss``, ``iou``,
        ``dice`` and ``val_epoch`` (the epoch each validation entry was logged
        at, so validation points can be placed on the epoch axis).
    """
    # One float as Python prints it: optional sign, digits with optional
    # fraction, optional exponent -- or nan/inf.
    number = r'([-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?|[-+]?(?:nan|inf))'
    # Compiled once here instead of re-parsing the pattern for every line.
    train_re = re.compile(r'Train loss: ' + number + r'\s*\|\| @ epoch (\d+)')
    val_re = re.compile(
        r'Total score: ' + number
        + r', IOU: ' + number
        + r', DICE: ' + number
        + r'\s*\|\| @ epoch (\d+)'
    )

    metrics = {
        'epoch': [],
        'train_loss': [],
        'val_loss': [],
        'iou': [],
        'dice': [],
        'val_epoch': [],
    }

    with open(log_path, 'r', encoding='utf-8', errors='replace') as f:
        for line in f:
            # e.g. "Train loss: 0.455222487449646|| @ epoch 0."
            train_match = train_re.search(line)
            if train_match:
                loss = float(train_match.group(1))
                epoch = int(train_match.group(2))
                # Only accept an epoch beyond the entries stored so far, so a
                # duplicated line for the same epoch is ignored.
                if epoch >= len(metrics['train_loss']):
                    metrics['train_loss'].append(loss)
                    metrics['epoch'].append(epoch)

            # e.g. "Total score: 0.367, IOU: 0.012, DICE: 0.022 || @ epoch 2."
            val_match = val_re.search(line)
            if val_match:
                metrics['val_loss'].append(float(val_match.group(1)))
                metrics['iou'].append(float(val_match.group(2)))
                metrics['dice'].append(float(val_match.group(3)))
                metrics['val_epoch'].append(int(val_match.group(4)))

    return metrics
|
|
|
|
|
|
|
|
|
def plot_loss_curves(metrics: Dict[str, List[float]], save_path: str) -> None:
|
|
|
"""
|
|
|
绘制训练和验证损失曲线。
|
|
|
|
|
|
Args:
|
|
|
metrics: 训练指标字典
|
|
|
save_path: 图像保存路径
|
|
|
"""
|
|
|
fig, ax = plt.subplots(figsize=(10, 6))
|
|
|
|
|
|
epochs = metrics['epoch']
|
|
|
train_loss = metrics['train_loss']
|
|
|
|
|
|
# 绘制训练损失
|
|
|
ax.plot(epochs, train_loss, 'b-', label='Training Loss', linewidth=2)
|
|
|
|
|
|
# 如果有验证损失,绘制验证损失
|
|
|
if metrics['val_loss']:
|
|
|
# 验证是每隔几个epoch进行的,需要对齐x轴
|
|
|
val_epochs = np.linspace(0, max(epochs), len(metrics['val_loss']))
|
|
|
ax.plot(val_epochs, metrics['val_loss'], 'r--', label='Validation Loss', linewidth=2)
|
|
|
|
|
|
ax.set_xlabel('Epoch', fontsize=12)
|
|
|
ax.set_ylabel('Loss', fontsize=12)
|
|
|
ax.set_title('Training and Validation Loss Curves', fontsize=14)
|
|
|
ax.legend(loc='upper right', fontsize=10)
|
|
|
ax.grid(True, alpha=0.3)
|
|
|
|
|
|
# 添加最佳损失标注
|
|
|
if train_loss:
|
|
|
min_idx = np.argmin(train_loss)
|
|
|
ax.annotate(f'Best: {train_loss[min_idx]:.4f}',
|
|
|
xy=(epochs[min_idx], train_loss[min_idx]),
|
|
|
xytext=(epochs[min_idx] + 2, train_loss[min_idx] + 0.1),
|
|
|
arrowprops=dict(arrowstyle='->', color='blue'),
|
|
|
fontsize=10, color='blue')
|
|
|
|
|
|
plt.tight_layout()
|
|
|
plt.savefig(save_path, dpi=150, bbox_inches='tight')
|
|
|
plt.close()
|
|
|
print(f"损失曲线已保存至: {save_path}")
|
|
|
|
|
|
|
|
|
def plot_metric_curves(metrics: Dict[str, List[float]], save_path: str) -> None:
    """
    Plot IoU and Dice side by side and save the figure.

    Points are placed at their logged epoch when ``metrics['val_epoch']`` has
    one entry per validation. Otherwise the x-axis is the validation step
    index. The best (highest finite) value of each metric is annotated.
    When either metric list is empty, a warning is printed and nothing is
    plotted.

    Args:
        metrics: metric dict as returned by ``parse_log_file``.
        save_path: output image path.
    """
    iou = metrics['iou']
    dice = metrics['dice']
    if not iou or not dice:
        print("警告: 没有IoU/Dice指标数据")
        return

    # Prefer real epoch numbers when the parser recorded them.
    val_epoch = metrics.get('val_epoch')
    if val_epoch and len(val_epoch) == len(iou):
        x_values, x_label = list(val_epoch), 'Epoch'
    else:
        x_values, x_label = list(range(len(iou))), 'Validation Step'

    # Per panel: values, line format, legend label, y label, title, colour.
    panels = (
        (iou, 'g-o', 'IoU', 'IoU', 'Intersection over Union (IoU)', 'green'),
        (dice, 'm-s', 'Dice', 'Dice Score', 'Dice Coefficient', 'purple'),
    )

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    try:
        for ax, (values, fmt, label, y_label, title, color) in zip(axes, panels):
            ax.plot(x_values, values, fmt, label=label, linewidth=2, markersize=6)
            ax.set_xlabel(x_label, fontsize=12)
            ax.set_ylabel(y_label, fontsize=12)
            ax.set_title(title, fontsize=14)
            ax.grid(True, alpha=0.3)

            # Annotate the best (highest) value, ignoring nan/inf entries.
            finite = [i for i, v in enumerate(values) if np.isfinite(v)]
            if finite:
                best = max(finite, key=values.__getitem__)
                ax.annotate(f'Best: {values[best]:.4f}',
                            xy=(x_values[best], values[best]),
                            # Offset in screen points, independent of axis scale.
                            xytext=(10, 10), textcoords='offset points',
                            arrowprops=dict(arrowstyle='->', color=color),
                            fontsize=10, color=color)

        fig.tight_layout()
        fig.savefig(save_path, dpi=150, bbox_inches='tight')
    finally:
        # Release the figure even if plotting or saving fails.
        plt.close(fig)

    print(f"指标曲线已保存至: {save_path}")
|
|
|
|
|
|
|
|
|
def plot_combined_dashboard(metrics: Dict[str, List[float]], save_path: str) -> None:
    """
    Draw a 2x3 training dashboard and save it to ``save_path``.

    The panels, in row-major order, are:

    * training loss per epoch;
    * validation loss per validation step;
    * IoU;
    * Dice;
    * a bar chart comparing mean training loss with mean validation loss;
    * a text panel with summary statistics.

    A curve panel whose source list is empty still gets its title, labels
    and grid, but shows no data.

    Args:
        metrics: metric dict as returned by ``parse_log_file``.
        save_path: output image path.
    """
    fig = plt.figure(figsize=(16, 10))
    grid = fig.add_gridspec(2, 3, hspace=0.3, wspace=0.3)

    def decorate(ax, title, x_label, y_label):
        # Shared title / axis-label / grid styling for the four curve panels.
        ax.set_title(title, fontsize=12, fontweight='bold')
        ax.set_xlabel(x_label)
        ax.set_ylabel(y_label)
        ax.grid(True, alpha=0.3)

    # Top-left: training loss per epoch, with the area under the curve shaded.
    train_ax = fig.add_subplot(grid[0, 0])
    train_loss = metrics['train_loss']
    if train_loss:
        epoch_axis = metrics['epoch']
        train_ax.plot(epoch_axis, train_loss, 'b-', linewidth=2)
        train_ax.fill_between(epoch_axis, train_loss, alpha=0.3)
    decorate(train_ax, 'Training Loss', 'Epoch', 'Loss')

    # Top-middle: validation loss against the validation step index.
    val_ax = fig.add_subplot(grid[0, 1])
    val_loss = metrics['val_loss']
    if val_loss:
        val_steps = range(len(val_loss))
        val_ax.plot(val_steps, val_loss, 'r-', linewidth=2)
        val_ax.fill_between(val_steps, val_loss, alpha=0.3, color='red')
    decorate(val_ax, 'Validation Loss', 'Validation Step', 'Loss')

    # Top-right: IoU per validation step.
    iou_ax = fig.add_subplot(grid[0, 2])
    iou = metrics['iou']
    if iou:
        iou_ax.plot(range(len(iou)), iou, 'g-o', linewidth=2, markersize=4)
    decorate(iou_ax, 'IoU Score', 'Validation Step', 'IoU')

    # Bottom-left: Dice per validation step.
    dice_ax = fig.add_subplot(grid[1, 0])
    dice = metrics['dice']
    if dice:
        dice_ax.plot(range(len(dice)), dice, 'm-s', linewidth=2, markersize=4)
    decorate(dice_ax, 'Dice Score', 'Validation Step', 'Dice')

    # Bottom-middle: mean train vs. mean validation loss, drawn only when both exist.
    bar_ax = fig.add_subplot(grid[1, 1])
    if train_loss and val_loss:
        averages = [np.mean(train_loss), np.mean(val_loss)]
        positions = np.arange(2)
        bars = bar_ax.bar(positions, averages, color=['blue', 'red'], alpha=0.7)
        bar_ax.set_xticks(positions)
        bar_ax.set_xticklabels(['Avg Train Loss', 'Avg Val Loss'])
        bar_ax.set_title('Average Loss Comparison', fontsize=12, fontweight='bold')
        # Write each bar's value just above it.
        for bar, avg in zip(bars, averages):
            bar_ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.01,
                        f'{avg:.4f}', ha='center', va='bottom', fontsize=10)

    # Bottom-right: plain-text summary box (axes hidden).
    stats_ax = fig.add_subplot(grid[1, 2])
    stats_ax.axis('off')

    summary = []
    if train_loss:
        summary += [
            f"Total Epochs: {len(metrics['epoch'])}\n",
            f"Final Train Loss: {train_loss[-1]:.4f}\n",
            f"Best Train Loss: {min(train_loss):.4f}\n",
        ]
    if val_loss:
        summary += [
            f"Final Val Loss: {val_loss[-1]:.4f}\n",
            f"Best Val Loss: {min(val_loss):.4f}\n",
        ]
    if iou:
        summary.append(f"Best IoU: {max(iou):.4f}\n")
    if dice:
        summary.append(f"Best Dice: {max(dice):.4f}\n")
    stats_text = "Training Statistics\n" + "="*30 + "\n\n" + "".join(summary)

    stats_ax.text(0.1, 0.5, stats_text, transform=stats_ax.transAxes, fontsize=11,
                  verticalalignment='center', fontfamily='monospace',
                  bbox=dict(boxstyle='round', facecolor='lightgray', alpha=0.5))

    fig.suptitle('One-Prompt Medical Image Segmentation - Training Dashboard',
                 fontsize=16, fontweight='bold', y=0.98)

    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close()
    print(f"训练仪表板已保存至: {save_path}")
|
|
|
|
|
|
|
|
|
def main():
    """
    Command-line entry point: parse a run's training log and write all plots.

    The log is looked up as ``<log_dir>/Log/*.log``. When several files match,
    the most recently modified one is used (ties broken by file name), rather
    than whichever file the filesystem happened to list first. Plots are
    written to ``--output_dir`` (default: ``<log_dir>/visualizations``).

    Raises:
        SystemExit: with a non-zero status when no log file is found, so
            calling scripts can detect the failure.
    """
    parser = argparse.ArgumentParser(description='可视化训练过程')
    parser.add_argument('--log_dir', type=str, required=True, help='日志目录路径')
    parser.add_argument('--output_dir', type=str, default=None, help='输出目录')
    args = parser.parse_args()

    log_dir = Path(args.log_dir)

    # Look for the log BEFORE creating any directory: mkdir(parents=True)
    # below would otherwise create a mistyped --log_dir on disk.
    log_files = list(log_dir.glob('Log/*.log'))
    if not log_files:
        raise SystemExit(f"错误: 在 {log_dir}/Log/ 目录下未找到日志文件")

    log_path = max(log_files, key=lambda p: (p.stat().st_mtime, p.name))
    if len(log_files) > 1:
        print(f"发现 {len(log_files)} 个日志文件, 使用最新的: {log_path.name}")
    print(f"正在解析日志文件: {log_path}")

    output_dir = Path(args.output_dir) if args.output_dir else log_dir / 'visualizations'
    output_dir.mkdir(parents=True, exist_ok=True)

    metrics = parse_log_file(str(log_path))
    print(f"解析完成: {len(metrics['epoch'])} 个epoch, "
          f"{len(metrics['val_loss'])} 次验证")

    plot_loss_curves(metrics, str(output_dir / 'loss_curves.png'))
    plot_metric_curves(metrics, str(output_dir / 'metric_curves.png'))
    plot_combined_dashboard(metrics, str(output_dir / 'training_dashboard.png'))

    print(f"\n所有可视化已保存至: {output_dir}")
|
|
|
|
|
|
|
|
|
# Run the command-line entry point only when executed as a script,
# not when this module is imported.
if __name__ == '__main__':
    main()
|