|
|
# scripts/extended_training.py
|
|
|
import warnings
|
|
|
|
|
|
warnings.filterwarnings("ignore", category=UserWarning, module="torchvision")
|
|
|
|
|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
import torch.optim as optim
|
|
|
from torchvision import models, transforms
|
|
|
from PIL import Image
|
|
|
from pathlib import Path
|
|
|
import yaml
|
|
|
import time
|
|
|
import numpy as np
|
|
|
|
|
|
|
|
|
# Banner separator emitted at import time, before main() starts.
print(60 * "=")
|
|
|
|
|
|
|
|
|
def main():
    """Run one extended (400-step) neural style transfer pass on CPU.

    Reads image paths from ``configs/default.yaml`` relative to the project
    root, optimizes a copy of the content image against VGG16 content and
    style (Gram-matrix) features with a heavy style weighting, and saves the
    result as a timestamped JPEG in the configured output directory.

    Side effects: reads two image files and a YAML config, downloads/loads
    pretrained VGG16 weights, writes one JPEG, prints progress to stdout.
    """
    # Load configuration; only the 'paths' section is consumed below — the
    # 'model'/'training' overrides mirror the hard-coded constants for
    # documentation purposes.
    config_path = Path(__file__).parent.parent / "configs" / "default.yaml"
    with open(config_path, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)

    # ========== Key tuning: strongly favor style over content ==========
    config['model']['image_size'] = 256
    config['model']['content_weight'] = 1e3  # lowered content weight
    config['model']['style_weight'] = 1e7    # greatly raised style weight
    config['model']['learning_rate'] = 0.003
    config['training']['num_steps'] = 400
    config['multi_scale']['enabled'] = False

    print("📋 长时间训练参数:")
    print(f" - 图像尺寸: 256x256")
    print(f" - 内容权重: 1e3 (风格权重的1/10000)")
    print(f" - 风格权重: 1e7 (内容权重的10000倍)")
    print(f" - 权重比例: 1:10000 (强调风格)")
    print(f" - 学习率: 0.003")
    print(f" - 训练步数: 400 (约5分钟)")

    # Device setup: CPU only; cap threads to avoid oversubscription.
    device = torch.device("cpu")
    torch.set_num_threads(2)

    # Resolve image paths relative to the project root.
    project_root = Path(__file__).parent.parent
    content_path = project_root / config['paths']['content_image']
    style_path = project_root / config['paths']['style_image']
    output_dir = project_root / config['paths']['output_dir']

    print(f"\n📷 加载图像...")
    content_img = Image.open(content_path).convert('RGB')
    style_img = Image.open(style_path).convert('RGB')

    image_size = 256
    content_img = content_img.resize((image_size, image_size), Image.Resampling.LANCZOS)
    style_img = style_img.resize((image_size, image_size), Image.Resampling.LANCZOS)

    print(f" 调整到: {image_size}x{image_size}")

    # Preprocess with ToTensor only; ImageNet normalization is applied
    # separately in normalize_for_vgg() so the optimized tensor itself
    # stays in [0, 1] and can be clamped/saved directly.
    preprocess = transforms.Compose([
        transforms.ToTensor(),
    ])

    content_tensor = preprocess(content_img).unsqueeze(0).to(device)
    style_tensor = preprocess(style_img).unsqueeze(0).to(device)
    # The optimization variable: starts as the content image.
    target_tensor = content_tensor.clone().requires_grad_(True)

    # Load the VGG16 feature extractor, frozen in eval mode.
    print("\n🔧 加载VGG16模型...")
    start_time = time.time()
    vgg = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1).features.to(device).eval()
    for param in vgg.parameters():
        param.requires_grad_(False)
    print(f"✅ 模型加载完成 ({time.time() - start_time:.2f}秒)")

    # Layer indices into torchvision VGG16 `.features`.
    # FIX: the original used VGG19 indices ("21" and "0,5,10,19,28") while
    # loading VGG16. In VGG16, conv4_2 is index 19 and the conv*_1 layers
    # are 0, 5, 10, 17, 24 — corrected to match the intended layers.
    content_layers = ["19"]  # conv4_2
    style_layers = ["0", "5", "10", "17", "24"]  # conv1_1 .. conv5_1

    def get_features_fixed(x, layers):
        """Run x through vgg and collect {index-as-str: activation} for the
        requested layer indices (outputs of the modules at those indices)."""
        features = {}
        x_tmp = x
        layers_int = {int(l) for l in layers}
        for idx, (name, layer) in enumerate(vgg._modules.items()):
            x_tmp = layer(x_tmp)
            if idx in layers_int:
                features[str(idx)] = x_tmp
        return features

    # Extract reference features.
    print("\n📊 提取特征...")
    start_time = time.time()

    def normalize_for_vgg(tensor):
        """Apply the ImageNet mean/std normalization VGG16 was trained with."""
        mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device)
        std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device)
        return (tensor - mean) / std

    content_features = get_features_fixed(normalize_for_vgg(content_tensor), content_layers)
    style_features = get_features_fixed(normalize_for_vgg(style_tensor), style_layers)

    def gram_matrix_fixed(tensor):
        """Return the (b, c, c) Gram matrix of a (b, c, h, w) feature map,
        normalized by h*w only (original-paper scaling, not c*h*w)."""
        b, c, h, w = tensor.size()
        tensor_flat = tensor.view(b, c, -1)  # [b, c, h*w]
        gram = torch.bmm(tensor_flat, tensor_flat.transpose(1, 2))  # [b, c, c]
        return gram / (h * w)

    # Precompute the style Gram matrices once; they are constant targets.
    style_grams = {}
    for layer, feat in style_features.items():
        style_grams[layer] = gram_matrix_fixed(feat)
        print(f" 层 {layer} Gram矩阵形状: {style_grams[layer].shape}")

    print(f"✅ 特征提取完成 ({time.time() - start_time:.2f}秒)")

    # Optimizer: Adam on the image pixels, LR halved every 100 steps.
    learning_rate = 0.003
    optimizer = optim.Adam([target_tensor], lr=learning_rate)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.5)

    # Training parameters.
    content_weight = 1e3
    style_weight = 1e7  # very high to emphasize style transfer
    num_steps = 400
    log_interval = 40

    print(f"\n🎨 开始长时间训练 (约5分钟)...")
    print(f" 优化器: Adam (初始lr={learning_rate})")
    print(f" 内容权重: {content_weight}")
    print(f" 风格权重: {style_weight} (非常高,强调风格转移)")
    print(f" 权重比例: 1:{int(style_weight / content_weight)}")

    start_time = time.time()

    # Loss histories for the post-run summary.
    loss_history = []
    content_loss_history = []
    style_loss_history = []

    for step in range(1, num_steps + 1):
        optimizer.zero_grad()

        # Re-extract features of the current target (normalized for VGG).
        target_normalized = normalize_for_vgg(target_tensor)
        target_features = get_features_fixed(target_normalized, content_layers + style_layers)

        # Content loss: MSE between target and content feature maps.
        content_loss = torch.tensor(0.0, device=device)
        for layer in content_layers:
            if layer in target_features and layer in content_features:
                diff = target_features[layer] - content_features[layer]
                content_loss += torch.mean(diff ** 2)

        # Style loss: MSE between Gram matrices.
        style_loss = torch.tensor(0.0, device=device)
        for layer in style_layers:
            if layer in target_features and layer in style_grams:
                target_gram = gram_matrix_fixed(target_features[layer])
                style_gram = style_grams[layer]
                diff = target_gram - style_gram
                style_loss += torch.mean(diff ** 2)

        # Weighted total loss.
        total_loss = content_weight * content_loss + style_weight * style_loss

        total_loss.backward()

        # Gradient clipping (generous limit) to keep updates stable.
        torch.nn.utils.clip_grad_norm_([target_tensor], max_norm=10.0)

        optimizer.step()

        # FIX: StepLR counts its own .step() calls, so it must be stepped
        # once per iteration. The original called it only every 100 steps,
        # which with step_size=100 meant the LR never actually decayed
        # within 400 steps (the log message printed an unchanged LR).
        scheduler.step()
        if step % 100 == 0:
            current_lr = optimizer.param_groups[0]['lr']
            print(f"📉 第 {step} 步: 学习率降低到 {current_lr:.6f}")

        # Keep pixel values in the displayable [0, 1] range.
        with torch.no_grad():
            target_tensor.data.clamp_(0, 1)

        # Record losses.
        loss_history.append(total_loss.item())
        content_loss_history.append(content_loss.item())
        style_loss_history.append(style_loss.item())

        # Periodic progress report.
        if step % log_interval == 0 or step == num_steps:
            elapsed = time.time() - start_time
            steps_per_second = step / elapsed if elapsed > 0 else 0
            remaining_steps = num_steps - step
            estimated_remaining = remaining_steps / steps_per_second if steps_per_second > 0 else 0

            # Relative contribution of the style term to the total loss.
            style_contribution = style_weight * style_loss.item()
            content_contribution = content_weight * content_loss.item()
            total_contribution = style_contribution + content_contribution
            # Guard against division by zero when both losses vanish.
            style_ratio = (style_contribution / total_contribution * 100
                           if total_contribution > 0 else 0.0)

            print(f"⏳ 第 {step:3d}/{num_steps} | "
                  f"总损失: {total_loss.item():8.2f} | "
                  f"内容: {content_loss.item():6.4f} | "
                  f"风格: {style_loss.item():6.4f} | "
                  f"风格贡献: {style_ratio:.1f}% | "
                  f"用时: {elapsed:.1f}s | "
                  f"剩余: {estimated_remaining:.1f}s")

    total_time = time.time() - start_time

    # Loss summary.
    print(f"\n📈 损失分析:")
    print(f" 总损失变化: {loss_history[0]:.2f} → {loss_history[-1]:.2f}")
    print(f" 内容损失变化: {content_loss_history[0]:.4f} → {content_loss_history[-1]:.4f}")
    print(f" 风格损失变化: {style_loss_history[0]:.4f} → {style_loss_history[-1]:.4f}")

    # Save the result. No de-normalization is needed because the optimized
    # tensor was kept in [0, 1] (normalization only happened inside
    # normalize_for_vgg for feature extraction).
    postprocess = transforms.Compose([
        transforms.ToPILImage()
    ])

    output_dir.mkdir(parents=True, exist_ok=True)
    with torch.no_grad():
        result_img = postprocess(target_tensor.squeeze(0).cpu())

    timestamp = time.strftime("%H%M%S")
    output_path = output_dir / f"extended_result_{timestamp}.jpg"
    result_img.save(output_path, quality=95, optimize=True)

    print(f"\n🎉 训练完成!")
    print(f" 总用时: {total_time:.1f}秒")
    print(f" 结果保存至: {output_path}")

    # Best-effort display; headless environments have no image viewer.
    try:
        result_img.show(title="风格迁移结果")
    except Exception:
        print(" ℹ️ 请手动打开图像查看")
|
|
|
|
|
|
|
|
|
# Entry-point guard: run the pipeline only when executed as a script,
# not when this module is imported.
if __name__ == "__main__":
    main()