You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
26 lines
796 B
26 lines
796 B
import yaml
|
|
import torch
|
|
import matplotlib.pyplot as plt
|
|
from torchvision.utils import save_image
|
|
import os
|
|
|
|
def load_config(config_path):
|
|
"""加载配置文件"""
|
|
with open(config_path, 'r', encoding='utf-8') as f:
|
|
return yaml.safe_load(f)
|
|
|
|
def save_reconstructed_images(original, reconstructed, path, nrow=8):
|
|
"""保存重建图像对比"""
|
|
comparison = torch.cat([original[:nrow], reconstructed[:nrow]])
|
|
save_image(comparison.cpu(), path, nrow=nrow)
|
|
|
|
def plot_losses(train_losses, test_losses, save_path):
|
|
"""绘制损失曲线"""
|
|
plt.figure(figsize=(10, 5))
|
|
plt.plot(train_losses, label='Train Loss')
|
|
plt.plot(test_losses, label='Test Loss')
|
|
plt.xlabel('Epoch')
|
|
plt.ylabel('Loss')
|
|
plt.legend()
|
|
plt.savefig(save_path)
|
|
plt.close() |