From 626e5f869ec0cab1cc44a7f603dcaedb22fd4cce Mon Sep 17 00:00:00 2001 From: p82b7rtam <1761133400@qq.com> Date: Mon, 29 Dec 2025 17:44:03 +0800 Subject: [PATCH] Delete 'run_style_transfer.py' --- run_style_transfer.py | 263 ------------------------------------------ 1 file changed, 263 deletions(-) delete mode 100644 run_style_transfer.py diff --git a/run_style_transfer.py b/run_style_transfer.py deleted file mode 100644 index 80f5be2..0000000 --- a/run_style_transfer.py +++ /dev/null @@ -1,263 +0,0 @@ -# scripts/extended_training.py -import warnings - -warnings.filterwarnings("ignore", category=UserWarning, module="torchvision") - -import torch -import torch.nn as nn -import torch.optim as optim -from torchvision import models, transforms -from PIL import Image -from pathlib import Path -import yaml -import time -import numpy as np - - -print("=" * 60) - - -def main(): - # 加载配置 - config_path = Path(__file__).parent.parent / "configs" / "default.yaml" - with open(config_path, 'r', encoding='utf-8') as f: - config = yaml.safe_load(f) - - # ========== 关键修复:大幅提高风格影响 ========== - config['model']['image_size'] = 256 - config['model']['content_weight'] = 1e3 # 降低内容权重 - config['model']['style_weight'] = 1e7 # 大幅提高风格权重 - config['model']['learning_rate'] = 0.003 - config['training']['num_steps'] = 400 - config['multi_scale']['enabled'] = False - - print("📋 长时间训练参数:") - print(f" - 图像尺寸: 256x256") - print(f" - 内容权重: 1e3 (风格权重的1/10000)") - print(f" - 风格权重: 1e7 (内容权重的10000倍)") - print(f" - 权重比例: 1:10000 (强调风格)") - print(f" - 学习率: 0.003") - print(f" - 训练步数: 400 (约5分钟)") - - # 设备设置 - device = torch.device("cpu") - torch.set_num_threads(2) - - # 加载图像 - project_root = Path(__file__).parent.parent - content_path = project_root / config['paths']['content_image'] - style_path = project_root / config['paths']['style_image'] - output_dir = project_root / config['paths']['output_dir'] - - print(f"\n📷 加载图像...") - content_img = Image.open(content_path).convert('RGB') - style_img = Image.open(style_path).convert('RGB') - - image_size = 256 - content_img = content_img.resize((image_size, image_size), Image.Resampling.LANCZOS) - style_img = style_img.resize((image_size, image_size), Image.Resampling.LANCZOS) - - print(f" 调整到: {image_size}x{image_size}") - - # 预处理(移除标准化) - preprocess = transforms.Compose([ - transforms.ToTensor(), - # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # 注释掉 - ]) - - content_tensor = preprocess(content_img).unsqueeze(0).to(device) - style_tensor = preprocess(style_img).unsqueeze(0).to(device) - target_tensor = content_tensor.clone().requires_grad_(True) - - # 加载VGG16模型 - print("\n🔧 加载VGG16模型...") - start_time = time.time() - vgg = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1).features.to(device).eval() - - for param in vgg.parameters(): - param.requires_grad_(False) - - print(f"✅ 模型加载完成 ({time.time() - start_time:.2f}秒)") - - # 定义层索引 - content_layers = ["21"] # conv4_2 - style_layers = ["0", "5", "10", "19", "28"] # conv1_1, conv2_1, conv3_1, conv4_1, conv5_1 - - # 特征提取函数 - def get_features_fixed(x, layers): - features = {} - x_tmp = x - layers_int = set([int(l) for l in layers]) - - for idx, (name, layer) in enumerate(vgg._modules.items()): - x_tmp = layer(x_tmp) - if idx in layers_int: - features[str(idx)] = x_tmp - - return features - - # 提取特征 - print("\n📊 提取特征...") - start_time = time.time() - - # 标准化输入以匹配VGG的预训练权重 - def normalize_for_vgg(tensor): - mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device) - std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device) - return (tensor - mean) / std - - content_features = get_features_fixed(normalize_for_vgg(content_tensor), content_layers) - style_features = get_features_fixed(normalize_for_vgg(style_tensor), style_layers) - - # 修改Gram矩阵计算(关键修复) - def gram_matrix_fixed(tensor): - b, c, h, w = tensor.size() - - # 展平特征 - tensor_flat = tensor.view(b, c, -1) # [b, c, h*w] - - # 计算Gram矩阵 - gram = torch.bmm(tensor_flat, tensor_flat.transpose(1, 2)) # [b, c, c] - - # 关键修改:只除以h*w,而不是c*h*w(原始论文做法) - gram = gram / (h * w) - - return gram - - # 预计算风格Gram矩阵 - style_grams = {} - for layer, feat in style_features.items(): - style_grams[layer] = gram_matrix_fixed(feat) - print(f" 层 {layer} Gram矩阵形状: {style_grams[layer].shape}") - - print(f"✅ 特征提取完成 ({time.time() - start_time:.2f}秒)") - - # 优化器设置 - learning_rate = 0.003 # 使用更高的学习率 - optimizer = optim.Adam([target_tensor], lr=learning_rate) - scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.5) - - # 训练参数 - content_weight = 1e3 - style_weight = 1e7 # 大幅提高 - num_steps = 400 - log_interval = 40 - - print(f"\n🎨 开始长时间训练 (约5分钟)...") - print(f" 优化器: Adam (初始lr={learning_rate})") - print(f" 内容权重: {content_weight}") - print(f" 风格权重: {style_weight} (非常高,强调风格转移)") - print(f" 权重比例: 1:{int(style_weight / content_weight)}") - - start_time = time.time() - - # 记录损失历史 - loss_history = [] - content_loss_history = [] - style_loss_history = [] - - for step in range(1, num_steps + 1): - optimizer.zero_grad() - - # 提取目标特征(需要标准化) - target_normalized = normalize_for_vgg(target_tensor) - target_features = get_features_fixed(target_normalized, content_layers + style_layers) - - # 计算内容损失 - content_loss = torch.tensor(0.0, device=device) - for layer in content_layers: - if layer in target_features and layer in content_features: - diff = target_features[layer] - content_features[layer] - content_loss += torch.mean(diff ** 2) - - # 计算风格损失 - style_loss = torch.tensor(0.0, device=device) - for layer in style_layers: - if layer in target_features and layer in style_grams: - target_gram = gram_matrix_fixed(target_features[layer]) - style_gram = style_grams[layer] - diff = target_gram - style_gram - style_loss += torch.mean(diff ** 2) - - # 总损失 - total_loss = content_weight * content_loss + style_weight * style_loss - - # 反向传播 - total_loss.backward() - - # 梯度裁剪(提高限制) - torch.nn.utils.clip_grad_norm_([target_tensor], max_norm=10.0) - - optimizer.step() - - # 更新学习率 - if step % 100 == 0: - scheduler.step() - current_lr = optimizer.param_groups[0]['lr'] - print(f"📉 第 {step} 步: 学习率降低到 {current_lr:.6f}") - - # 限制像素值 - with torch.no_grad(): - target_tensor.data.clamp_(0, 1) - - # 记录损失 - loss_history.append(total_loss.item()) - content_loss_history.append(content_loss.item()) - style_loss_history.append(style_loss.item()) - - # 定期输出 - if step % log_interval == 0 or step == num_steps: - elapsed = time.time() - start_time - steps_per_second = step / elapsed if elapsed > 0 else 0 - remaining_steps = num_steps - step - estimated_remaining = remaining_steps / steps_per_second if steps_per_second > 0 else 0 - - # 计算风格损失贡献 - style_contribution = style_weight * style_loss.item() - content_contribution = content_weight * content_loss.item() - style_ratio = style_contribution / (style_contribution + content_contribution) * 100 - - print(f"⏳ 第 {step:3d}/{num_steps} | " - f"总损失: {total_loss.item():8.2f} | " - f"内容: {content_loss.item():6.4f} | " - f"风格: {style_loss.item():6.4f} | " - f"风格贡献: {style_ratio:.1f}% | " - f"用时: {elapsed:.1f}s | " - f"剩余: {estimated_remaining:.1f}s") - - total_time = time.time() - start_time - - # 分析损失变化 - print(f"\n📈 损失分析:") - print(f" 总损失变化: {loss_history[0]:.2f} → {loss_history[-1]:.2f}") - print(f" 内容损失变化: {content_loss_history[0]:.4f} → {content_loss_history[-1]:.4f}") - print(f" 风格损失变化: {style_loss_history[0]:.4f} → {style_loss_history[-1]:.4f}") - - # 保存结果 - postprocess = transforms.Compose([ - # transforms.Normalize(mean=[-0.485/0.229, -0.456/0.224, -0.406/0.225], - # std=[1/0.229, 1/0.224, 1/0.225]), # 如果之前没标准化,这里也不需要 - transforms.ToPILImage() - ]) - - output_dir.mkdir(parents=True, exist_ok=True) - with torch.no_grad(): - result_img = postprocess(target_tensor.squeeze(0).cpu()) - - timestamp = time.strftime("%H%M%S") - output_path = output_dir / f"extended_result_{timestamp}.jpg" - result_img.save(output_path, quality=95, optimize=True) - - print(f"\n🎉 训练完成!") - print(f" 总用时: {total_time:.1f}秒") - print(f" 结果保存至: {output_path}") - - # 尝试显示结果 - try: - result_img.show(title="风格迁移结果") - except: - print(" ℹ️ 请手动打开图像查看") - - -if __name__ == "__main__": - main() \ No newline at end of file