# scripts/extended_training.py
#
# Neural style transfer (Gatys et al.) on CPU with VGG16 features,
# tuned to strongly favour style over content (weight ratio 1:10000).
import warnings

warnings.filterwarnings("ignore", category=UserWarning, module="torchvision")

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models, transforms
from PIL import Image
from pathlib import Path
import yaml
import time
import numpy as np


print("=" * 60)


def _normalize_for_vgg(tensor, device):
    """Normalize a [0, 1] image tensor with ImageNet mean/std for VGG input.

    The optimized image itself is kept in [0, 1]; normalization is applied
    only on the way into the network.
    """
    mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device)
    std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device)
    return (tensor - mean) / std


def _gram_matrix(tensor):
    """Gram matrix of a (b, c, h, w) feature map, normalized by h*w only.

    Dividing by h*w (not c*h*w) follows the original Gatys et al. paper and
    keeps style-loss magnitudes large enough to matter.
    """
    b, c, h, w = tensor.size()
    flat = tensor.view(b, c, -1)                       # (b, c, h*w)
    gram = torch.bmm(flat, flat.transpose(1, 2))       # (b, c, c)
    return gram / (h * w)


def _get_features(model, x, layers):
    """Run x through model's sequential children, collecting the outputs of
    the module indices named in `layers` (strings of ints), keyed by index."""
    wanted = {int(l) for l in layers}
    features = {}
    out = x
    for idx, layer in enumerate(model):
        out = layer(out)
        if idx in wanted:
            features[str(idx)] = out
    return features


def main():
    """Run one long (~400-step) style-transfer optimization and save the result.

    Reads configs/default.yaml for image paths, overrides the tuning knobs
    in-memory, optimizes the content image's pixels directly with Adam, and
    writes a timestamped JPEG into the configured output directory.
    """
    # Load configuration (paths come from here; the tuning values below
    # override whatever the file says).
    config_path = Path(__file__).parent.parent / "configs" / "default.yaml"
    with open(config_path, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)

    # Key tuning: greatly boost the style term relative to content.
    config['model']['image_size'] = 256
    config['model']['content_weight'] = 1e3   # lowered content weight
    config['model']['style_weight'] = 1e7     # greatly raised style weight
    config['model']['learning_rate'] = 0.003
    config['training']['num_steps'] = 400
    config['multi_scale']['enabled'] = False

    print("📋 长时间训练参数:")
    print(f" - 图像尺寸: 256x256")
    print(f" - 内容权重: 1e3 (风格权重的1/10000)")
    print(f" - 风格权重: 1e7 (内容权重的10000倍)")
    print(f" - 权重比例: 1:10000 (强调风格)")
    print(f" - 学习率: 0.003")
    print(f" - 训练步数: 400 (约5分钟)")

    # CPU-only run; cap threads to keep the machine responsive.
    device = torch.device("cpu")
    torch.set_num_threads(2)

    # Resolve image paths relative to the project root.
    project_root = Path(__file__).parent.parent
    content_path = project_root / config['paths']['content_image']
    style_path = project_root / config['paths']['style_image']
    output_dir = project_root / config['paths']['output_dir']

    print(f"\n📷 加载图像...")
    content_img = Image.open(content_path).convert('RGB')
    style_img = Image.open(style_path).convert('RGB')

    image_size = 256
    content_img = content_img.resize((image_size, image_size), Image.Resampling.LANCZOS)
    style_img = style_img.resize((image_size, image_size), Image.Resampling.LANCZOS)

    print(f" 调整到: {image_size}x{image_size}")

    # Preprocess to [0, 1] tensors only; ImageNet normalization happens in
    # _normalize_for_vgg so the optimized pixels stay directly viewable.
    preprocess = transforms.Compose([
        transforms.ToTensor(),
    ])

    content_tensor = preprocess(content_img).unsqueeze(0).to(device)
    style_tensor = preprocess(style_img).unsqueeze(0).to(device)
    # The image being optimized: starts as a copy of the content image.
    target_tensor = content_tensor.clone().requires_grad_(True)

    # Load the pretrained VGG16 feature extractor, frozen.
    print("\n🔧 加载VGG16模型...")
    start_time = time.time()
    vgg = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1).features.to(device).eval()

    for param in vgg.parameters():
        param.requires_grad_(False)

    print(f"✅ 模型加载完成 ({time.time() - start_time:.2f}秒)")

    # Module indices into vgg16().features.
    # NOTE: for VGG16 these indices are: 21 = conv4_3, and
    # 0, 5, 10, 19, 28 = conv1_1, conv2_1, conv3_1, conv4_2, conv5_3.
    # (The original comments used VGG19 layer names, which were wrong here.)
    content_layers = ["21"]
    style_layers = ["0", "5", "10", "19", "28"]

    # Extract the fixed reference features once, up front.
    print("\n📊 提取特征...")
    start_time = time.time()

    content_features = _get_features(
        vgg, _normalize_for_vgg(content_tensor, device), content_layers)
    style_features = _get_features(
        vgg, _normalize_for_vgg(style_tensor, device), style_layers)

    # Precompute the style Gram matrices (targets never change).
    style_grams = {}
    for layer, feat in style_features.items():
        style_grams[layer] = _gram_matrix(feat)
        print(f" 层 {layer} Gram矩阵形状: {style_grams[layer].shape}")

    print(f"✅ 特征提取完成 ({time.time() - start_time:.2f}秒)")

    # Optimizer on the image pixels; halve the LR every 100 steps.
    learning_rate = 0.003
    optimizer = optim.Adam([target_tensor], lr=learning_rate)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.5)

    content_weight = 1e3
    style_weight = 1e7
    num_steps = 400
    log_interval = 40

    print(f"\n🎨 开始长时间训练 (约5分钟)...")
    print(f" 优化器: Adam (初始lr={learning_rate})")
    print(f" 内容权重: {content_weight}")
    print(f" 风格权重: {style_weight} (非常高,强调风格转移)")
    print(f" 权重比例: 1:{int(style_weight / content_weight)}")

    start_time = time.time()

    loss_history = []
    content_loss_history = []
    style_loss_history = []

    for step in range(1, num_steps + 1):
        optimizer.zero_grad()

        # Re-extract features of the current image (normalized for VGG).
        target_features = _get_features(
            vgg, _normalize_for_vgg(target_tensor, device),
            content_layers + style_layers)

        # Content loss: MSE against the content image's features.
        content_loss = torch.tensor(0.0, device=device)
        for layer in content_layers:
            if layer in target_features and layer in content_features:
                diff = target_features[layer] - content_features[layer]
                content_loss += torch.mean(diff ** 2)

        # Style loss: MSE between Gram matrices.
        style_loss = torch.tensor(0.0, device=device)
        for layer in style_layers:
            if layer in target_features and layer in style_grams:
                target_gram = _gram_matrix(target_features[layer])
                diff = target_gram - style_grams[layer]
                style_loss += torch.mean(diff ** 2)

        total_loss = content_weight * content_loss + style_weight * style_loss

        total_loss.backward()

        # Clip pixel gradients to keep individual steps bounded.
        torch.nn.utils.clip_grad_norm_([target_tensor], max_norm=10.0)

        optimizer.step()

        # FIX: StepLR counts its own .step() calls, so it must be stepped
        # every iteration. The original called scheduler.step() only once per
        # 100 iterations, so with step_size=100 the LR never actually decayed
        # within 400 steps even though the decay message was printed.
        scheduler.step()
        if step % 100 == 0:
            current_lr = optimizer.param_groups[0]['lr']
            print(f"📉 第 {step} 步: 学习率降低到 {current_lr:.6f}")

        # Keep pixels in the valid [0, 1] range (in-place, outside autograd).
        with torch.no_grad():
            target_tensor.clamp_(0, 1)

        loss_history.append(total_loss.item())
        content_loss_history.append(content_loss.item())
        style_loss_history.append(style_loss.item())

        # Periodic progress report with an ETA and the style term's share.
        if step % log_interval == 0 or step == num_steps:
            elapsed = time.time() - start_time
            steps_per_second = step / elapsed if elapsed > 0 else 0
            remaining_steps = num_steps - step
            estimated_remaining = remaining_steps / steps_per_second if steps_per_second > 0 else 0

            style_contribution = style_weight * style_loss.item()
            content_contribution = content_weight * content_loss.item()
            style_ratio = style_contribution / (style_contribution + content_contribution) * 100

            print(f"⏳ 第 {step:3d}/{num_steps} | "
                  f"总损失: {total_loss.item():8.2f} | "
                  f"内容: {content_loss.item():6.4f} | "
                  f"风格: {style_loss.item():6.4f} | "
                  f"风格贡献: {style_ratio:.1f}% | "
                  f"用时: {elapsed:.1f}s | "
                  f"剩余: {estimated_remaining:.1f}s")

    total_time = time.time() - start_time

    # Summarize how the losses evolved over the run.
    print(f"\n📈 损失分析:")
    print(f" 总损失变化: {loss_history[0]:.2f} → {loss_history[-1]:.2f}")
    print(f" 内容损失变化: {content_loss_history[0]:.4f} → {content_loss_history[-1]:.4f}")
    print(f" 风格损失变化: {style_loss_history[0]:.4f} → {style_loss_history[-1]:.4f}")

    # Save the result. No de-normalization step is needed because the image
    # was optimized directly in [0, 1].
    postprocess = transforms.Compose([
        transforms.ToPILImage()
    ])

    output_dir.mkdir(parents=True, exist_ok=True)
    with torch.no_grad():
        result_img = postprocess(target_tensor.squeeze(0).cpu())

    timestamp = time.strftime("%H%M%S")
    output_path = output_dir / f"extended_result_{timestamp}.jpg"
    result_img.save(output_path, quality=95, optimize=True)

    print(f"\n🎉 训练完成!")
    print(f" 总用时: {total_time:.1f}秒")
    print(f" 结果保存至: {output_path}")

    # Best-effort display only — viewers are unavailable on headless machines.
    # (Was a bare `except:`, which would also swallow KeyboardInterrupt.)
    try:
        result_img.show(title="风格迁移结果")
    except Exception:
        print(" ℹ️ 请手动打开图像查看")


if __name__ == "__main__":
    main()