You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
NST/scripts/run_style_transfer.py

263 lines
9.2 KiB

This file contains invisible Unicode characters!

This file contains invisible Unicode characters that may be processed differently from what appears below. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to reveal hidden characters.

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

# scripts/extended_training.py
import warnings
warnings.filterwarnings("ignore", category=UserWarning, module="torchvision")
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models, transforms
from PIL import Image
from pathlib import Path
import yaml
import time
import numpy as np
# Visual separator before the run banner.
separator = "=" * 60
print(separator)
def main():
    """Run single-scale neural style transfer (Gatys-style optimization) on CPU.

    Loads content/style image paths from configs/default.yaml, optimizes a copy
    of the content image against VGG16 content/style losses for 400 steps, and
    saves the stylized result into the configured output directory.

    Raises:
        FileNotFoundError: if the config or either input image is missing.
    """
    # ---- Configuration -------------------------------------------------
    # Load the base YAML config, then override the knobs for this run.
    config_path = Path(__file__).parent.parent / "configs" / "default.yaml"
    with open(config_path, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)

    # ========== Key tweak: strongly favour style over content ==========
    config['model']['image_size'] = 256
    config['model']['content_weight'] = 1e3   # lower content weight
    config['model']['style_weight'] = 1e7     # much higher style weight
    config['model']['learning_rate'] = 0.003
    config['training']['num_steps'] = 400
    config['multi_scale']['enabled'] = False

    print("📋 长时间训练参数:")
    print(" - 图像尺寸: 256x256")
    print(" - 内容权重: 1e3 (风格权重的1/10000)")
    print(" - 风格权重: 1e7 (内容权重的10000倍)")
    print(" - 权重比例: 1:10000 (强调风格)")
    print(" - 学习率: 0.003")
    print(" - 训练步数: 400 (约5分钟)")

    # ---- Device --------------------------------------------------------
    device = torch.device("cpu")
    torch.set_num_threads(2)  # keep CPU usage modest on shared machines

    # ---- Images --------------------------------------------------------
    project_root = Path(__file__).parent.parent
    content_path = project_root / config['paths']['content_image']
    style_path = project_root / config['paths']['style_image']
    output_dir = project_root / config['paths']['output_dir']

    print("\n📷 加载图像...")
    content_img = Image.open(content_path).convert('RGB')
    style_img = Image.open(style_path).convert('RGB')
    image_size = 256
    content_img = content_img.resize((image_size, image_size), Image.Resampling.LANCZOS)
    style_img = style_img.resize((image_size, image_size), Image.Resampling.LANCZOS)
    print(f" 调整到: {image_size}x{image_size}")

    # Preprocessing deliberately skips ImageNet normalization: the optimized
    # image lives in [0, 1]; normalization is applied only when feeding VGG
    # (see normalize_for_vgg below).
    preprocess = transforms.Compose([
        transforms.ToTensor(),
    ])
    content_tensor = preprocess(content_img).unsqueeze(0).to(device)
    style_tensor = preprocess(style_img).unsqueeze(0).to(device)
    target_tensor = content_tensor.clone().requires_grad_(True)

    # ---- Model ---------------------------------------------------------
    print("\n🔧 加载VGG16模型...")
    start_time = time.time()
    vgg = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1).features.to(device).eval()
    for param in vgg.parameters():
        param.requires_grad_(False)  # only the target image is optimized
    print(f"✅ 模型加载完成 ({time.time() - start_time:.2f}秒)")

    # VGG16 feature-module indices (Gatys et al. layer choices).
    content_layers = ["21"]                      # conv4_2
    style_layers = ["0", "5", "10", "19", "28"]  # conv1_1, conv2_1, conv3_1, conv4_1, conv5_1

    def get_features_fixed(x, layers):
        """Return {index-as-str: activation} for the requested VGG layer indices."""
        features = {}
        wanted = {int(l) for l in layers}
        deepest = max(wanted)
        out = x
        for idx, layer in enumerate(vgg):
            out = layer(out)
            if idx in wanted:
                features[str(idx)] = out
            if idx >= deepest:
                # No need to run layers past the deepest requested one.
                break
        return features

    print("\n📊 提取特征...")
    start_time = time.time()

    def normalize_for_vgg(tensor):
        """Apply ImageNet mean/std so inputs match VGG's pretraining distribution."""
        mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device)
        std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device)
        return (tensor - mean) / std

    content_features = get_features_fixed(normalize_for_vgg(content_tensor), content_layers)
    style_features = get_features_fixed(normalize_for_vgg(style_tensor), style_layers)

    def gram_matrix_fixed(tensor):
        """Gram matrix normalized by h*w only (original-paper scaling, not c*h*w)."""
        b, c, h, w = tensor.size()
        flat = tensor.view(b, c, -1)                  # [b, c, h*w]
        gram = torch.bmm(flat, flat.transpose(1, 2))  # [b, c, c]
        return gram / (h * w)

    # Precompute style Gram matrices once; they are constant during training.
    style_grams = {}
    for layer, feat in style_features.items():
        style_grams[layer] = gram_matrix_fixed(feat)
        print(f"{layer} Gram矩阵形状: {style_grams[layer].shape}")
    print(f"✅ 特征提取完成 ({time.time() - start_time:.2f}秒)")

    # ---- Optimizer -----------------------------------------------------
    learning_rate = 0.003  # relatively high; scheduler decays it below
    optimizer = optim.Adam([target_tensor], lr=learning_rate)
    # StepLR counts scheduler.step() calls, so it must be stepped every
    # iteration for step_size=100 to mean "halve the LR every 100 steps".
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.5)

    # Training hyper-parameters (mirrors the config overrides above).
    content_weight = 1e3
    style_weight = 1e7
    num_steps = 400
    log_interval = 40

    print(f"\n🎨 开始长时间训练 (约5分钟)...")
    print(f" 优化器: Adam (初始lr={learning_rate})")
    print(f" 内容权重: {content_weight}")
    print(f" 风格权重: {style_weight} (非常高,强调风格转移)")
    print(f" 权重比例: 1:{int(style_weight / content_weight)}")

    start_time = time.time()
    loss_history = []
    content_loss_history = []
    style_loss_history = []

    for step in range(1, num_steps + 1):
        optimizer.zero_grad()

        # Forward pass: normalized target through VGG for both loss groups.
        target_normalized = normalize_for_vgg(target_tensor)
        target_features = get_features_fixed(target_normalized, content_layers + style_layers)

        # Content loss: MSE against the content image's conv4_2 activations.
        content_loss = torch.tensor(0.0, device=device)
        for layer in content_layers:
            if layer in target_features and layer in content_features:
                diff = target_features[layer] - content_features[layer]
                content_loss += torch.mean(diff ** 2)

        # Style loss: MSE between Gram matrices on each style layer.
        style_loss = torch.tensor(0.0, device=device)
        for layer in style_layers:
            if layer in target_features and layer in style_grams:
                target_gram = gram_matrix_fixed(target_features[layer])
                diff = target_gram - style_grams[layer]
                style_loss += torch.mean(diff ** 2)

        total_loss = content_weight * content_loss + style_weight * style_loss
        total_loss.backward()

        # Clip gradients so a single update cannot blow up the image.
        torch.nn.utils.clip_grad_norm_([target_tensor], max_norm=10.0)
        optimizer.step()

        # BUG FIX: the original called scheduler.step() only every 100
        # training steps; with step_size=100 that means the LR would never
        # actually decay within 400 steps. Step it every iteration instead.
        scheduler.step()
        if step % 100 == 0:
            current_lr = optimizer.param_groups[0]['lr']
            print(f"📉 第 {step} 步: 学习率降低到 {current_lr:.6f}")

        # Keep the optimized image inside the valid [0, 1] pixel range.
        with torch.no_grad():
            target_tensor.data.clamp_(0, 1)

        loss_history.append(total_loss.item())
        content_loss_history.append(content_loss.item())
        style_loss_history.append(style_loss.item())

        # Periodic progress report with ETA and style/content contribution.
        if step % log_interval == 0 or step == num_steps:
            elapsed = time.time() - start_time
            steps_per_second = step / elapsed if elapsed > 0 else 0
            remaining_steps = num_steps - step
            estimated_remaining = remaining_steps / steps_per_second if steps_per_second > 0 else 0
            style_contribution = style_weight * style_loss.item()
            content_contribution = content_weight * content_loss.item()
            style_ratio = style_contribution / (style_contribution + content_contribution) * 100
            print(f"⏳ 第 {step:3d}/{num_steps} | "
                  f"总损失: {total_loss.item():8.2f} | "
                  f"内容: {content_loss.item():6.4f} | "
                  f"风格: {style_loss.item():6.4f} | "
                  f"风格贡献: {style_ratio:.1f}% | "
                  f"用时: {elapsed:.1f}s | "
                  f"剩余: {estimated_remaining:.1f}s")

    total_time = time.time() - start_time

    # ---- Loss summary ---------------------------------------------------
    print(f"\n📈 损失分析:")
    print(f" 总损失变化: {loss_history[0]:.2f} → {loss_history[-1]:.2f}")
    print(f" 内容损失变化: {content_loss_history[0]:.4f} → {content_loss_history[-1]:.4f}")
    print(f" 风格损失变化: {style_loss_history[0]:.4f} → {style_loss_history[-1]:.4f}")

    # ---- Save result ----------------------------------------------------
    # No inverse normalization needed: the optimized tensor is already [0, 1].
    postprocess = transforms.Compose([
        transforms.ToPILImage()
    ])
    output_dir.mkdir(parents=True, exist_ok=True)
    with torch.no_grad():
        result_img = postprocess(target_tensor.squeeze(0).cpu())
    timestamp = time.strftime("%H%M%S")
    output_path = output_dir / f"extended_result_{timestamp}.jpg"
    result_img.save(output_path, quality=95, optimize=True)

    print(f"\n🎉 训练完成!")
    print(f" 总用时: {total_time:.1f}秒")
    print(f" 结果保存至: {output_path}")

    # Best-effort preview; a viewer may not be available (headless box).
    # BUG FIX: bare `except:` also swallowed KeyboardInterrupt/SystemExit.
    try:
        result_img.show(title="风格迁移结果")
    except Exception:
        print(" 请手动打开图像查看")
if __name__ == "__main__":
    # Script entry point: run the full style-transfer pipeline.
    main()