commit 5aea9c39c0894c0a0175e2c5d563d6bb44e8d617 Author: 张嘉欣 <1761133400@qq.com> Date: Mon Dec 29 17:42:47 2025 +0800 项目初始化:基于深度学习的AI风格迁移系统 diff --git a/.idea/.gitignore b/.idea/.gitignore new file mode 100644 index 0000000..359bb53 --- /dev/null +++ b/.idea/.gitignore @@ -0,0 +1,3 @@ +# 默认忽略的文件 +/shelf/ +/workspace.xml diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml new file mode 100644 index 0000000..105ce2d --- /dev/null +++ b/.idea/inspectionProfiles/profiles_settings.xml @@ -0,0 +1,6 @@ + + + + \ No newline at end of file diff --git a/.idea/misc.xml b/.idea/misc.xml new file mode 100644 index 0000000..c920e16 --- /dev/null +++ b/.idea/misc.xml @@ -0,0 +1,7 @@ + + + + + + \ No newline at end of file diff --git a/.idea/modules.xml b/.idea/modules.xml new file mode 100644 index 0000000..7d9e181 --- /dev/null +++ b/.idea/modules.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/style_transfer_ai.iml b/.idea/style_transfer_ai.iml new file mode 100644 index 0000000..e89a2b8 --- /dev/null +++ b/.idea/style_transfer_ai.iml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml new file mode 100644 index 0000000..77b61a5 --- /dev/null +++ b/.idea/vcs.xml @@ -0,0 +1,7 @@ + + + + + + + \ No newline at end of file diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..146123b --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,24 @@ +# .pre-commit-config.yaml +repos: + - repo: https://github.com/psf/black + rev: 23.3.0 + hooks: + - id: black + args: [--line-length=88] + + - repo: https://github.com/pycqa/flake8 + rev: 5.0.4 + hooks: + - id: flake8 + args: [--max-line-length=88] + + - repo: https://github.com/pycqa/isort + rev: 5.10.1 + hooks: + - id: isort + + - repo: https://github.com/python/mypy + rev: v1.3.0 + hooks: + - id: mypy + args: [--ignore-missing-imports] \ No newline at end 
of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..d294955 --- /dev/null +++ b/README.md @@ -0,0 +1,73 @@ +# 🎨 AI 画风迁移器 (StyleTransferAI) + +一个新颖、轻量级的Python项目,旨在通过神经风格迁移算法,实现快速有趣的艺术风格转换效果。 + +## 🚀 快速开始 + +1. **克隆项目** + ```bash + git clone https://github.com/yourname/style_transfer_ai.git + cd style_transfer_ai1 + ``` + +2. **创建虚拟环境并安装依赖** + ```bash + python -m venv venv + source venv/bin/activate # Linux/Mac + # venv\Scripts\activate # Windows + pip install -r requirements.txt + ``` + +3. **准备图像** + 在 `data/raw/` 目录下放置你的内容图 (`content.jpg`) 和风格图 (`style.jpg`)。 + +4. **运行程序** + ```bash + python scripts/run_style_transfer.py + ``` + +5. **查看结果** + 结果将保存在 `data/processed/result_时间戳.jpg`。 + +## 💡 项目特点 + +- **高质量风格迁移**: 基于VGG16深度学习模型,生成高质量艺术效果。 +- **速度快**: CPU优化,无需GPU即可快速运行。 +- **可视化监控**: 实时显示训练损失,生成详细损失曲线图。 +- **高度可定制**: 修改 `configs/default.yaml` 或核心代码即可调整风格迁移效果。 +- **多模式训练**: 提供快速、扩展、高级三种训练模式。 + +## 📂 项目结构 + +``` +style_transfer_ai/ +│ +├── configs/ # 配置文件目录 +│ └── default.yaml # 默认超参数配置(学习率、迭代次数、权重等) +│ +├── data/ # 数据目录 +│ ├── raw/ # 原始输入图像 +│ │ ├── content.jpg # 内容图像(用户自定义) +│ │ └── style.jpg # 风格图像(用户自定义) +│ └── processed/ # 输出结果目录(自动生成) +│ └── result_20251202_2227.jpg # 示例输出(带时间戳) +│ +├── models/ # 模型相关代码 +│ └── vgg16_extractor.py # VGG16特征提取器封装 +│ +├── scripts/ # 可执行脚本 +│ └── run_style_transfer.py # 主程序入口 +│ +├── src/ # 核心源代码 +│ ├── style_transfer.py # 风格迁移核心算法实现 +│ ├── loss.py # 内容损失、风格损失计算 +│ ├── utils.py # 图像预处理、后处理、可视化工具 +│ └── trainer.py # 训练循环与优化器管理 +│ +├── outputs/ # 日志与可视化输出 +│ └── loss_curve_*.png # 损失变化曲线图(按运行时间命名) +│ +├── requirements.txt # Python依赖列表 +├── README.md # 项目说明文档(即本文件) +└── .gitignore # Git忽略规则 +``` \ No newline at end of file diff --git a/configs/default.yaml b/configs/default.yaml new file mode 100644 index 0000000..f25384e --- /dev/null +++ b/configs/default.yaml @@ -0,0 +1,36 @@ +# configs/stable_config.yaml +# 稳定训练配置 + +# 模型配置 +model: + model_type: "vgg16" + image_size: 192 + content_layers: ["19"] # conv4_2 
+ style_layers: ["0", "5", "10", "17", "24"] # 全部5个层 + content_weight: 1e4 # 增加内容权重 + style_weight: 5e4 # 降低风格权重 + learning_rate: 0.005 # 低学习率稳定训练 + +# 路径配置 +paths: + content_image: "data/raw/content.jpg" + style_image: "data/raw/tem.jpg" + output_dir: "data/processed" + +# 训练配置 +training: + num_steps: 150 # 减少步数 + log_interval: 15 + tv_weight: 1e-5 + +# 多尺度配置(先禁用,稳定后再启用) +multi_scale: + enabled: false + scales: [192] + scale_weights: [1.0] + +# 性能配置 +performance: + use_gpu: false + batch_size: 1 + precision: "float32" \ No newline at end of file diff --git a/data/processed/extended_result_133911.jpg b/data/processed/extended_result_133911.jpg new file mode 100644 index 0000000..3cf8281 Binary files /dev/null and b/data/processed/extended_result_133911.jpg differ diff --git a/data/processed/extended_result_174121.jpg b/data/processed/extended_result_174121.jpg new file mode 100644 index 0000000..4f086d2 Binary files /dev/null and b/data/processed/extended_result_174121.jpg differ diff --git a/data/processed/extended_result_220811.jpg b/data/processed/extended_result_220811.jpg new file mode 100644 index 0000000..5f400fe Binary files /dev/null and b/data/processed/extended_result_220811.jpg differ diff --git a/data/processed/tem.jpg b/data/processed/tem.jpg new file mode 100644 index 0000000..7efc2c1 Binary files /dev/null and b/data/processed/tem.jpg differ diff --git a/data/raw/content.jpg b/data/raw/content.jpg new file mode 100644 index 0000000..d67a5a7 Binary files /dev/null and b/data/raw/content.jpg differ diff --git a/data/raw/tem.jpg b/data/raw/tem.jpg new file mode 100644 index 0000000..203618f Binary files /dev/null and b/data/raw/tem.jpg differ diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..ace7066 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,20 @@ +[build-system] +requires = ["setuptools>=45", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "style-transfer-ai" +version = "0.1.0" +description = "A 
# Indices into torchvision's ``vgg16().features`` for the convolutions used
# here: 0=conv1_1, 5=conv2_1, 10=conv3_1, 17=conv4_1, 19=conv4_2, 24=conv5_1.
# (The previous values 21/19/28 are the VGG19 positions of those layers.)
_VGG16_CONTENT_LAYERS = ("19",)
_VGG16_STYLE_LAYERS = ("0", "5", "10", "17", "24")

# Channel statistics expected by the ImageNet-pretrained VGG weights.
_IMAGENET_MEAN = (0.485, 0.456, 0.406)
_IMAGENET_STD = (0.229, 0.224, 0.225)

# Settings of this "extended training" run. They replace the values read from
# configs/default.yaml so that the run has a single source of truth.
_EXTENDED_OVERRIDES = {
    "model": {
        "image_size": 256,
        "content_weight": 1e3,   # lowered content weight
        "style_weight": 1e7,     # strongly raised style weight
        "learning_rate": 0.003,
    },
    "training": {"num_steps": 400, "log_interval": 40},
    "multi_scale": {"enabled": False},
}


def _apply_overrides(config, overrides):
    """Merge ``overrides`` (section -> values) into ``config`` in place."""
    for section, values in overrides.items():
        section_cfg = config.get(section)
        if not isinstance(section_cfg, dict):
            section_cfg = {}
            config[section] = section_cfg
        section_cfg.update(values)
    return config


def _load_rgb_square(path, size):
    """Open ``path``, convert to RGB and resize to ``size`` x ``size``."""
    with Image.open(path) as img:
        return img.convert("RGB").resize((size, size), Image.Resampling.LANCZOS)


def _extract_features(vgg, x, layers):
    """Run ``x`` through ``vgg`` and collect the outputs at the given indices.

    Args:
        vgg: sequential feature network (children are visited in order).
        x: input batch.
        layers: iterable of layer indices given as strings.

    Returns:
        Dict mapping ``str(index)`` to the tensor produced at that index.
    """
    wanted = {int(name) for name in layers}
    features = {}
    # NOTE(review): torchvision's VGG is believed to use in-place ReLU, in
    # which case a tensor captured at a conv index is overwritten by the
    # following activation. All layers are run (no early exit) so that every
    # captured feature is treated the same way - confirm if it matters.
    for idx, layer in enumerate(vgg.children()):
        x = layer(x)
        if idx in wanted:
            features[str(idx)] = x
    return features


def _gram_matrix(tensor):
    """Return the per-sample Gram matrix of a (b, c, h, w) feature map.

    The product is divided by ``h * w`` only (not by ``c * h * w``).
    """
    b, c, h, w = tensor.size()
    flat = tensor.view(b, c, h * w)
    return torch.bmm(flat, flat.transpose(1, 2)) / (h * w)


def main():
    """Run the extended style-transfer optimisation and save the result.

    Reads ``configs/default.yaml``, applies the extended-run overrides, loads
    the content and style images, optimises a copy of the content image
    against VGG16 content/style losses with Adam, and writes the result to
    the configured output directory as ``extended_result_<HHMMSS>.jpg``.
    """
    print("=" * 60)

    project_root = Path(__file__).parent.parent

    # Load the configuration and apply the settings of this run.
    config_path = project_root / "configs" / "default.yaml"
    with open(config_path, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f) or {}
    _apply_overrides(config, _EXTENDED_OVERRIDES)

    model_cfg = config['model']
    training_cfg = config['training']
    image_size = int(model_cfg['image_size'])
    # float(): PyYAML reads values such as ``1e4`` (no decimal point) as str.
    content_weight = float(model_cfg['content_weight'])
    style_weight = float(model_cfg['style_weight'])
    learning_rate = float(model_cfg['learning_rate'])
    num_steps = int(training_cfg['num_steps'])
    log_interval = max(1, int(training_cfg['log_interval']))
    content_layers = [str(x) for x in model_cfg.get('content_layers', _VGG16_CONTENT_LAYERS)]
    style_layers = [str(x) for x in model_cfg.get('style_layers', _VGG16_STYLE_LAYERS)]

    print("📋 长时间训练参数:")
    print(f"   - 图像尺寸: {image_size}x{image_size}")
    print(f"   - 内容权重: {content_weight:g}")
    print(f"   - 风格权重: {style_weight:g}")
    if content_weight > 0:
        print(f"   - 权重比例: 1:{style_weight / content_weight:g} (强调风格)")
    print(f"   - 学习率: {learning_rate}")
    print(f"   - 训练步数: {num_steps} (约5分钟)")

    # Device setup: this script is CPU-only.
    device = torch.device("cpu")
    torch.set_num_threads(2)

    # Load the images.
    content_path = project_root / config['paths']['content_image']
    style_path = project_root / config['paths']['style_image']
    output_dir = project_root / config['paths']['output_dir']

    print(f"\n📷 加载图像...")
    content_img = _load_rgb_square(content_path, image_size)
    style_img = _load_rgb_square(style_path, image_size)
    print(f"   调整到: {image_size}x{image_size}")

    # The optimised image lives in [0, 1]; ImageNet normalisation is applied
    # only on the way into VGG (see normalize_for_vgg below).
    to_tensor = transforms.ToTensor()
    content_tensor = to_tensor(content_img).unsqueeze(0).to(device)
    style_tensor = to_tensor(style_img).unsqueeze(0).to(device)
    target_tensor = content_tensor.clone().requires_grad_(True)

    # Load the pretrained VGG16 feature network and freeze it.
    print("\n🔧 加载VGG16模型...")
    start_time = time.time()
    vgg = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1).features.to(device).eval()
    for param in vgg.parameters():
        param.requires_grad_(False)
    print(f"✅ 模型加载完成 ({time.time() - start_time:.2f}秒)")

    # Built once instead of on every call.
    mean = torch.tensor(_IMAGENET_MEAN, device=device).view(1, 3, 1, 1)
    std = torch.tensor(_IMAGENET_STD, device=device).view(1, 3, 1, 1)

    def normalize_for_vgg(tensor):
        return (tensor - mean) / std

    # Reference features are constants of the optimisation.
    print("\n📊 提取特征...")
    start_time = time.time()
    with torch.no_grad():
        content_features = _extract_features(vgg, normalize_for_vgg(content_tensor), content_layers)
        style_features = _extract_features(vgg, normalize_for_vgg(style_tensor), style_layers)
        style_grams = {layer: _gram_matrix(feat) for layer, feat in style_features.items()}
    for layer, gram in style_grams.items():
        print(f"   层 {layer} Gram矩阵形状: {gram.shape}")
    print(f"✅ 特征提取完成 ({time.time() - start_time:.2f}秒)")

    # Optimiser: Adam with the learning rate halved every 100 steps.
    optimizer = optim.Adam([target_tensor], lr=learning_rate)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.5)

    print(f"\n🎨 开始长时间训练 (约5分钟)...")
    print(f"   优化器: Adam (初始lr={learning_rate})")
    print(f"   内容权重: {content_weight}")
    print(f"   风格权重: {style_weight} (非常高,强调风格转移)")
    if content_weight > 0:
        print(f"   权重比例: 1:{int(style_weight / content_weight)}")

    start_time = time.time()

    loss_history = []
    content_loss_history = []
    style_loss_history = []
    all_layers = content_layers + style_layers

    for step in range(1, num_steps + 1):
        optimizer.zero_grad()

        target_features = _extract_features(vgg, normalize_for_vgg(target_tensor), all_layers)

        # Content loss: mean squared feature difference.
        content_loss = torch.tensor(0.0, device=device)
        for layer in content_layers:
            if layer in target_features and layer in content_features:
                diff = target_features[layer] - content_features[layer]
                content_loss = content_loss + torch.mean(diff ** 2)

        # Style loss: mean squared Gram-matrix difference.
        style_loss = torch.tensor(0.0, device=device)
        for layer in style_layers:
            if layer in target_features and layer in style_grams:
                diff = _gram_matrix(target_features[layer]) - style_grams[layer]
                style_loss = style_loss + torch.mean(diff ** 2)

        total_loss = content_weight * content_loss + style_weight * style_loss
        total_loss.backward()

        torch.nn.utils.clip_grad_norm_([target_tensor], max_norm=10.0)
        optimizer.step()

        # StepLR counts scheduler.step() calls, so it must be called on every
        # iteration for step_size=100 to mean "every 100 training steps".
        previous_lr = optimizer.param_groups[0]['lr']
        scheduler.step()
        current_lr = optimizer.param_groups[0]['lr']
        if current_lr != previous_lr:
            print(f"📉 第 {step} 步: 学习率降低到 {current_lr:.6f}")

        # Keep the image inside the valid pixel range.
        with torch.no_grad():
            target_tensor.clamp_(0, 1)

        loss_history.append(total_loss.item())
        content_loss_history.append(content_loss.item())
        style_loss_history.append(style_loss.item())

        if step % log_interval == 0 or step == num_steps:
            elapsed = time.time() - start_time
            steps_per_second = step / elapsed if elapsed > 0 else 0
            remaining_steps = num_steps - step
            estimated_remaining = remaining_steps / steps_per_second if steps_per_second > 0 else 0

            # Share of the weighted loss that comes from the style term.
            style_contribution = style_weight * style_loss.item()
            content_contribution = content_weight * content_loss.item()
            weighted_total = style_contribution + content_contribution
            style_ratio = style_contribution / weighted_total * 100 if weighted_total > 0 else 0.0

            print(f"⏳ 第 {step:3d}/{num_steps} | "
                  f"总损失: {total_loss.item():8.2f} | "
                  f"内容: {content_loss.item():6.4f} | "
                  f"风格: {style_loss.item():6.4f} | "
                  f"风格贡献: {style_ratio:.1f}% | "
                  f"用时: {elapsed:.1f}s | "
                  f"剩余: {estimated_remaining:.1f}s")

    total_time = time.time() - start_time

    # Loss summary (skipped when no step was run).
    if loss_history:
        print(f"\n📈 损失分析:")
        print(f"   总损失变化: {loss_history[0]:.2f} → {loss_history[-1]:.2f}")
        print(f"   内容损失变化: {content_loss_history[0]:.4f} → {content_loss_history[-1]:.4f}")
        print(f"   风格损失变化: {style_loss_history[0]:.4f} → {style_loss_history[-1]:.4f}")

    # Save the result; the tensor is already in [0, 1], no de-normalisation.
    output_dir.mkdir(parents=True, exist_ok=True)
    with torch.no_grad():
        result_img = transforms.ToPILImage()(target_tensor.squeeze(0).cpu())

    timestamp = time.strftime("%H%M%S")
    output_path = output_dir / f"extended_result_{timestamp}.jpg"
    result_img.save(output_path, quality=95, optimize=True)

    print(f"\n🎉 训练完成!")
    print(f"   总用时: {total_time:.1f}秒")
    print(f"   结果保存至: {output_path}")

    # Showing the image is best-effort (e.g. no viewer on a headless host);
    # KeyboardInterrupt/SystemExit are no longer swallowed.
    try:
        result_img.show(title="风格迁移结果")
    except Exception:
        print("   ℹ️ 请手动打开图像查看")


if __name__ == "__main__":
    main()
class BaseFeatureExtractor(ABC):
    """Interface shared by all feature extractors."""

    @abstractmethod
    def extract(self, image: np.ndarray) -> Dict[str, Any]:
        """Compute and return the features of ``image``."""

    @abstractmethod
    def get_features(self) -> Dict[str, Any]:
        """Return the features produced by the latest ``extract`` call."""


class StyleFeatureExtractor(BaseFeatureExtractor):
    """
    Classical (non-neural) style descriptor.

    "Style" is described by per-channel colour histograms, the mean absolute
    difference between neighbouring pixels, and per-channel mean / standard
    deviation.
    """

    def __init__(self, config: Dict[str, Any]):
        """
        Store the configuration and start with no extracted features.

        Args:
            config: configuration dictionary (kept as-is on the instance).
        """
        self.config = config
        self.style_features: Dict[str, Any] = {}

    def extract(self, image: np.ndarray) -> Dict[str, Any]:
        """
        Extract the style features of an image.

        Args:
            image: array indexed as (H, W, C); values are divided by 255
                before any statistic is computed.

        Returns:
            Dictionary with the keys ``color_histogram`` (3 x 32 bin counts,
            channels 0..2 concatenated), ``texture_strength`` (horizontal and
            vertical mean absolute neighbour difference), ``mean_color`` and
            ``color_std`` (per-channel statistics).
        """
        scaled = image.astype(np.float32) / 255.0

        # 32-bin histogram over [0, 1] for each of the first three channels.
        per_channel_counts = [
            np.histogram(scaled[:, :, ch].ravel(), bins=32, range=(0, 1))[0]
            for ch in range(3)
        ]

        # Texture: average magnitude of the first difference along each axis.
        horizontal = np.abs(np.diff(scaled, axis=1)).mean()
        vertical = np.abs(np.diff(scaled, axis=0)).mean()

        result = {
            'color_histogram': np.concatenate(per_channel_counts),
            'texture_strength': np.array([horizontal, vertical]),
            'mean_color': scaled.mean(axis=(0, 1)),
            'color_std': scaled.std(axis=(0, 1)),
        }

        self.style_features = result
        return result

    def get_features(self) -> Dict[str, Any]:
        """
        Return a shallow copy of the most recently extracted features.

        Returns:
            Style feature dictionary (empty before the first ``extract``).
        """
        return dict(self.style_features)


# Factory helper for the available extractor implementations.
def create_feature_extractor(extractor_type: str, config: Dict[str, Any]) -> BaseFeatureExtractor:
    """
    Build a feature extractor by name.

    Args:
        extractor_type: ``'traditional'`` or ``'neural'``.
        config: configuration dictionary passed to the extractor.

    Returns:
        The extractor instance.

    Raises:
        ValueError: if ``extractor_type`` is not one of the names above.
    """
    if extractor_type == 'neural':
        # Imported lazily so the traditional path does not import this module.
        from .neural_feature_extractor import NeuralStyleFeatureExtractor
        return NeuralStyleFeatureExtractor(config)
    if extractor_type == 'traditional':
        return StyleFeatureExtractor(config)
    raise ValueError(f"不支持的提取器类型: {extractor_type}")
class BaseImageSynthesizer(ABC):
    """Interface shared by all image synthesizers."""

    @abstractmethod
    def set_content(self, image: np.ndarray):
        """Set the content image."""
        pass

    @abstractmethod
    def set_style_features(self, features: Dict[str, Any]):
        """Set the target style features."""
        pass

    @abstractmethod
    def synthesize(self, initial_image: np.ndarray = None, **kwargs) -> np.ndarray:
        """Synthesize and return the stylised image."""
        pass


class ImageSynthesizer(BaseImageSynthesizer):
    """
    Classical image synthesizer that produces the final stylised image.

    The image is found by minimising ``content_loss + style_weight *
    style_loss`` with L-BFGS-B, where ``style_weight`` is read from the
    configuration (default 1.0).
    """

    def __init__(self, config: Dict[str, Any]):
        """
        Initialise the synthesizer.

        Args:
            config: configuration dictionary; ``style_weight`` is the only key
                read by this class.
        """
        self.config = config
        self.content_image: np.ndarray = None
        self.style_features: Dict[str, Any] = {}
        self.synthesized_image: np.ndarray = None

    def set_content(self, image: np.ndarray):
        """
        Set the content image.

        Args:
            image: content image indexed as (H, W, C); it is stored as
                float32 divided by 255.
        """
        self.content_image = image.astype(np.float32) / 255.0

    def set_style_features(self, features: Dict[str, Any]):
        """
        Set the target style features.

        Args:
            features: dictionary with at least ``color_histogram``,
                ``texture_strength`` and ``mean_color``.
        """
        self.style_features = features

    def synthesize(self, initial_image: np.ndarray = None, max_iter: int = 200) -> np.ndarray:
        """
        Synthesize the stylised image.

        Args:
            initial_image: starting point of the optimisation with the same
                shape as the content image. Integer arrays are treated as
                0-255 pixel data and scaled to [0, 1]; float arrays are used
                as given. Defaults to the content image.
            max_iter: maximum number of L-BFGS-B iterations.

        Returns:
            Stylised image as uint8, shape (H, W, C), values in [0, 255].

        Raises:
            ValueError: if the content image or the style features were not
                set, or ``initial_image`` has a different shape.
        """
        if self.content_image is None:
            raise ValueError("content image is not set; call set_content() first")
        if not self.style_features:
            raise ValueError("style features are not set; call set_style_features() first")

        if initial_image is None:
            initial_image = self.content_image.copy()
        else:
            initial_image = np.asarray(initial_image)
            if initial_image.shape != self.content_image.shape:
                raise ValueError(
                    f"initial_image shape {initial_image.shape} does not match "
                    f"content image shape {self.content_image.shape}"
                )
            if np.issubdtype(initial_image.dtype, np.integer):
                # Same scaling as set_content applies to pixel data.
                initial_image = initial_image.astype(np.float32) / 255.0

        # The optimiser works on a flat vector.
        h, w, c = self.content_image.shape
        x0 = initial_image.ravel()
        style_weight = self.config.get('style_weight', 1.0)

        def loss_function(x):
            img = x.reshape((h, w, c))
            # 1. Content loss: stay close to the content image.
            content_loss = np.mean((img - self.content_image) ** 2)
            # 2. Style loss: match the target style statistics.
            style_loss = self._calculate_style_loss(img)
            return content_loss + style_weight * style_loss

        # Keep every iterate inside [0, 1]: values outside that range would be
        # dropped by the histogram and are clipped from the result anyway.
        result = minimize(loss_function, x0, method='L-BFGS-B',
                          bounds=[(0.0, 1.0)] * x0.size,
                          options={'maxiter': max_iter, 'disp': False})

        # Back to image shape and the 0-255 range.
        synthesized = result.x.reshape((h, w, c))
        synthesized = np.clip(synthesized, 0, 1) * 255
        self.synthesized_image = synthesized.astype(np.uint8)

        return self.synthesized_image

    @staticmethod
    def _as_distribution(hist) -> np.ndarray:
        """Scale a histogram so it sums to 1 (all-zero input is returned as-is)."""
        hist = np.asarray(hist, dtype=np.float64)
        total = hist.sum()
        return hist / total if total > 0 else hist

    def _calculate_style_loss(self, image: np.ndarray) -> float:
        """
        Compute the loss between an image and the target style features.

        Args:
            image: current image, indexed as (H, W, C), values in [0, 1].

        Returns:
            Sum of the histogram KL divergence, the texture-strength MSE and
            the mean-colour MSE.
        """
        loss = 0.0

        # 1. Colour-histogram loss.
        current_counts = np.concatenate([
            np.histogram(image[:, :, ch].ravel(), bins=32, range=(0, 1))[0]
            for ch in range(3)
        ])
        # KL divergence is defined on probability distributions. Comparing raw
        # counts made the loss depend on the pixel count of the style image
        # versus the synthesised image and could even make it negative.
        target_hist = self._as_distribution(self.style_features['color_histogram'])
        current_hist = self._as_distribution(current_counts)
        epsilon = 1e-10
        kl_div = np.sum(target_hist * np.log((target_hist + epsilon) / (current_hist + epsilon)))
        loss += kl_div

        # 2. Texture-strength loss.
        diff_h = np.abs(np.diff(image, axis=1)).mean()
        diff_v = np.abs(np.diff(image, axis=0)).mean()
        current_texture = np.array([diff_h, diff_v])
        target_texture = self.style_features['texture_strength']
        loss += np.mean((current_texture - target_texture) ** 2)

        # 3. Mean-colour loss.
        current_mean = image.mean(axis=(0, 1))
        target_mean = self.style_features['mean_color']
        loss += np.mean((current_mean - target_mean) ** 2)

        return loss


# Factory helper for the available synthesizer implementations.
def create_image_synthesizer(synthesizer_type: str, config: Dict[str, Any]) -> BaseImageSynthesizer:
    """
    Build an image synthesizer by name.

    Args:
        synthesizer_type: ``'traditional'`` or ``'neural'``.
        config: configuration dictionary passed to the synthesizer.

    Returns:
        The synthesizer instance.

    Raises:
        ValueError: if ``synthesizer_type`` is not one of the names above.
    """
    if synthesizer_type == 'traditional':
        return ImageSynthesizer(config)
    elif synthesizer_type == 'neural':
        from .neural_image_synthesizer import NeuralImageSynthesizer
        return NeuralImageSynthesizer(config)
    else:
        raise ValueError(f"不支持的合成器类型: {synthesizer_type}")
class MultiScaleProcessor:
    """
    Multi-scale image processor.

    Runs a synthesizer once per configured scale and feeds the result of one
    scale into the next as its initial image.
    """

    def __init__(self, config: Dict[str, Any]):
        """
        Args:
            config: configuration dictionary. ``scales`` is the list of target
                sizes (longer image side, default ``[512, 256]``) and
                ``scale_weights`` the per-scale weights (default
                ``[1.0, 0.5]``).
        """
        self.config = config
        self.scales = config.get('scales', [512, 256])
        self.scale_weights = config.get('scale_weights', [1.0, 0.5])

    def process_multi_scale(self,
                            content_image: np.ndarray,
                            style_image: np.ndarray,
                            synthesizer: Any) -> np.ndarray:
        """
        Run style transfer at every configured scale.

        Args:
            content_image: content image.
            style_image: style image.
            synthesizer: object whose ``synthesize`` is called as
                ``synthesize(content, style, initial_image=..., scale_weight=...)``.
                NOTE(review): that call shape is taken from this method only;
                confirm it matches the synthesizer actually passed in.

        Returns:
            The image returned by the synthesizer for the last scale, or
            ``None`` when no scales are configured.

        Raises:
            ValueError: if there are fewer ``scale_weights`` than ``scales``.
        """
        # Fail before any (expensive) synthesis instead of with an IndexError
        # after the first scales have already been processed.
        if len(self.scale_weights) < len(self.scales):
            raise ValueError(
                f"scale_weights has {len(self.scale_weights)} entries but "
                f"{len(self.scales)} scales are configured"
            )

        current_result = None

        for scale, weight in zip(self.scales, self.scale_weights):
            print(f"处理尺度: {scale}x{scale}")

            content_scaled = self._resize_image(content_image, scale)
            style_scaled = self._resize_image(style_image, scale)

            # First scale: let the synthesizer choose its own starting point.
            # Later scales: start from the previous result, resized.
            if current_result is None:
                initial_image = None
            else:
                initial_image = self._resize_image(current_result, scale)

            current_result = synthesizer.synthesize(
                content_scaled,
                style_scaled,
                initial_image=initial_image,
                scale_weight=weight
            )

        return current_result

    @staticmethod
    def _target_size(h: int, w: int, size: int):
        """Return ``(new_w, new_h)`` with the longer side equal to ``size``.

        The shorter side keeps the aspect ratio and is at least 1 pixel, so
        very elongated images never produce an empty target size.
        """
        if h > w:
            return max(1, int(w * size / h)), size
        return size, max(1, int(h * size / w))

    def _resize_image(self, image: np.ndarray, size: int) -> np.ndarray:
        """Resize ``image`` so that its longer side is ``size`` pixels.

        Args:
            image: array whose first two axes are height and width; it must be
                convertible by ``PIL.Image.fromarray``.
            size: target length of the longer side.

        Returns:
            The resized image as an array.
        """
        h, w = image.shape[:2]
        new_w, new_h = self._target_size(h, w, size)

        # LANCZOS is a high-quality resampling filter.
        pil_image = Image.fromarray(image)
        resized = pil_image.resize((new_w, new_h), Image.LANCZOS)

        return np.array(resized)
class NeuralStyleFeatureExtractor:
    """
    Style feature extractor based on a pretrained VGG network.

    Style is described by the Gram matrices of selected convolution outputs.
    """

    # Style layers per supported model, named ``conv<block>_<index in block>``.
    _STYLE_LAYERS = {
        'vgg19': ['conv1_1', 'conv2_1', 'conv3_1', 'conv4_1', 'conv5_1'],
        'vgg16': ['conv1_2', 'conv2_2', 'conv3_3', 'conv4_3'],
    }

    def __init__(self, config: Dict[str, Any]):
        """
        Args:
            config: configuration dictionary; ``model_type`` (default
                ``'vgg19'``) and ``image_size`` (default 512) are read.
        """
        self.config = config
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = self._load_model()
        self.preprocess = self._get_preprocess()

        # _load_model treats every model_type other than 'vgg19' as VGG16, so
        # the layer selection does the same instead of raising KeyError after
        # the model has already been loaded.
        model_type = config.get('model_type', 'vgg19')
        if model_type not in self._STYLE_LAYERS:
            logging.getLogger(__name__).warning(
                "Unknown model_type %r; using the vgg16 style layers", model_type
            )
            model_type = 'vgg16'
        self.style_layers = list(self._STYLE_LAYERS[model_type])

        self.features = {}
        self.hooks = []
        self.style_features: Dict[str, Any] = {}

    def _load_model(self):
        """Load the pretrained VGG feature network, frozen and in eval mode."""
        model_type = self.config.get('model_type', 'vgg19')
        # ``weights=`` replaces the deprecated ``pretrained=True`` argument.
        if model_type == 'vgg19':
            model = models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1).features
        else:
            model = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1).features

        # Freeze the parameters.
        for param in model.parameters():
            param.requires_grad = False

        return model.to(self.device).eval()

    def _get_preprocess(self):
        """Build the preprocessing pipeline: resize, to tensor, normalise."""
        return transforms.Compose([
            transforms.Resize(self.config.get('image_size', 512)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            )
        ])

    def _hook_layer(self, layer_name):
        """Return a forward hook that stores the output under ``layer_name``."""

        def hook(module, input, output):
            # NOTE(review): if the following activation is in-place (as
            # torchvision's VGG is believed to be), the stored tensor ends up
            # holding post-activation values - confirm this is intended.
            self.features[layer_name] = output

        return hook

    @staticmethod
    def _named_convs(model):
        """Yield ``(name, layer)`` for every Conv2d child of ``model``.

        Names follow the VGG convention ``conv<block>_<n>``: the block number
        increases after each max-pooling layer and ``n`` counts the
        convolutions inside the block. (Deriving both numbers from one running
        index is only right for blocks with exactly two convolutions.)
        """
        block, index_in_block = 1, 0
        for layer in model.children():
            if isinstance(layer, nn.MaxPool2d):
                block += 1
                index_in_block = 0
            elif isinstance(layer, nn.Conv2d):
                index_in_block += 1
                yield f"conv{block}_{index_in_block}", layer

    def _register_hooks(self):
        """Register forward hooks on the configured style layers."""
        self.hooks = []
        for layer_name, layer in self._named_convs(self.model):
            if layer_name in self.style_layers:
                self.hooks.append(layer.register_forward_hook(self._hook_layer(layer_name)))

    def _remove_hooks(self):
        """Remove every registered hook."""
        for hook in self.hooks:
            hook.remove()
        self.hooks = []

    def gram_matrix(self, tensor):
        """Return the Gram matrix of a (b, c, h, w) tensor, divided by c*h*w."""
        batch_size, channels, height, width = tensor.size()
        features = tensor.view(batch_size, channels, height * width)
        gram = torch.bmm(features, features.transpose(1, 2))
        return gram / (channels * height * width)

    def extract(self, image: np.ndarray) -> Dict[str, Any]:
        """
        Extract the style features of an image.

        Args:
            image: numpy array (converted to uint8 and then to a PIL image) or
                an object the preprocessing pipeline accepts directly.

        Returns:
            Dictionary mapping each captured style layer name to its Gram
            matrix.
        """
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image.astype('uint8'))

        input_tensor = self.preprocess(image).unsqueeze(0).to(self.device)

        # Hooks are removed even if the forward pass fails, so they cannot
        # pile up on the model across calls.
        self._register_hooks()
        try:
            with torch.no_grad():
                self.model(input_tensor)
        finally:
            self._remove_hooks()

        style_features = {}
        for layer_name in self.style_layers:
            if layer_name in self.features:
                style_features[layer_name] = self.gram_matrix(self.features[layer_name])

        # Drop the raw activations; only the Gram matrices are kept.
        self.features.clear()
        self.style_features = style_features

        return style_features

    def get_features(self) -> Dict[str, Any]:
        """Return a shallow copy of the features from the latest ``extract``."""
        return dict(self.style_features)
class NeuralImageSynthesizer:
    """
    Image synthesizer that optimises an image against neural features.

    The target image starts as a copy of the content tensor and is updated
    with Adam to minimise a weighted sum of content, style and
    total-variation losses.
    """

    # Normalisation constants used by ``synthesize`` to map the optimised
    # tensor back to [0, 1] pixel values.
    _MEAN = (0.485, 0.456, 0.406)
    _STD = (0.229, 0.224, 0.225)

    def __init__(self, config: Dict[str, Any]):
        """
        Args:
            config: configuration dictionary; ``content_weight`` (default 1),
                ``style_weight`` (default 1e6) and ``tv_weight`` (default
                1e-6) are read.
        """
        self.config = config
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Loss weights.
        self.content_weight = config.get('content_weight', 1e0)
        self.style_weight = config.get('style_weight', 1e6)
        self.tv_weight = config.get('tv_weight', 1e-6)  # total-variation regulariser

        self.model = None
        self.content_tensor = None
        self.content_features = None
        self.style_features = None

    def set_content_features(self, content_tensor: torch.Tensor, content_features: Dict[str, torch.Tensor]):
        """Store the content tensor (start image) and its reference features."""
        self.content_tensor = content_tensor
        self.content_features = content_features

    def set_style_features(self, style_features: Dict[str, torch.Tensor]):
        """Store the target style features (layer name -> Gram matrix)."""
        self.style_features = style_features

    def _calculate_content_loss(self, target_features, content_features):
        """Sum of feature MSE over every layer in ``content_features``."""
        content_loss = 0
        for layer in content_features:
            target_feat = target_features[layer]
            content_feat = content_features[layer]
            content_loss += torch.nn.functional.mse_loss(target_feat, content_feat)
        return content_loss

    def _calculate_style_loss(self, target_features, style_features):
        """Sum of Gram-matrix MSE over every layer in ``style_features``."""
        style_loss = 0
        for layer in style_features:
            target_feat = target_features[layer]
            target_gram = self._gram_matrix(target_feat)
            style_gram = style_features[layer]
            style_loss += torch.nn.functional.mse_loss(target_gram, style_gram)
        return style_loss

    def _gram_matrix(self, tensor):
        """Return the Gram matrix of a (b, c, h, w) tensor, divided by c*h*w."""
        batch_size, channels, height, width = tensor.size()
        features = tensor.view(batch_size, channels, height * width)
        gram = torch.bmm(features, features.transpose(1, 2))
        return gram / (channels * height * width)

    def _total_variation_loss(self, image):
        """Mean absolute difference between vertical and horizontal neighbours."""
        tv_h = torch.mean(torch.abs(image[:, :, 1:, :] - image[:, :, :-1, :]))
        tv_w = torch.mean(torch.abs(image[:, :, :, 1:] - image[:, :, :, :-1]))
        return tv_h + tv_w

    def _compute_losses(self, image, model, content_layers, style_layers):
        """Return ``(total, content, style)`` losses of ``image``."""
        features = self._get_features(image, model, list(content_layers) + list(style_layers))
        content_loss = self._calculate_content_loss(
            {k: v for k, v in features.items() if k in content_layers},
            self.content_features
        )
        style_loss = self._calculate_style_loss(
            {k: v for k, v in features.items() if k in style_layers},
            self.style_features
        )
        tv_loss = self._total_variation_loss(image)
        total_loss = (self.content_weight * content_loss +
                      self.style_weight * style_loss +
                      self.tv_weight * tv_loss)
        return total_loss, content_loss, style_loss

    # ``list[str]`` (builtin generic) is used because ``typing.List`` is not
    # imported by this module; the previous ``List[str]`` annotation raised
    # NameError as soon as the module was imported.
    def synthesize(self,
                   model: nn.Module,
                   content_layers: list[str],
                   style_layers: list[str],
                   num_steps: int = 1000,
                   learning_rate: float = 0.01,
                   progress_callback: Optional[callable] = None) -> np.ndarray:
        """
        Synthesize the stylised image.

        Args:
            model: feature network whose children are visited in order.
            content_layers: names of the content layers (``conv<b>_<n>``).
            style_layers: names of the style layers.
            num_steps: number of optimisation steps.
            learning_rate: initial Adam learning rate.
            progress_callback: optional ``callback(step, num_steps, loss)``
                invoked every 10 steps.

        Returns:
            The lowest-loss image seen, shape (H, W, C), uint8 in [0, 255].

        Raises:
            ValueError: if content or style features have not been set.
        """
        if self.content_tensor is None or self.content_features is None:
            raise ValueError("content is not set; call set_content_features() first")
        if self.style_features is None:
            raise ValueError("style features are not set; call set_style_features() first")

        self.model = model

        # Start from the content image.
        target_image = self.content_tensor.clone().requires_grad_(True)

        optimizer = optim.Adam([target_image], lr=learning_rate)
        # step_size must be >= 1, also for num_steps < 3.
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=max(1, num_steps // 3), gamma=0.5)

        # The result is de-normalised with the ImageNet statistics below, so
        # the optimised tensor lives in normalised space. Valid pixels [0, 1]
        # therefore correspond to these per-channel bounds, not to [0, 1].
        mean = torch.tensor(self._MEAN, device=target_image.device,
                            dtype=target_image.dtype).view(1, 3, 1, 1)
        std = torch.tensor(self._STD, device=target_image.device,
                           dtype=target_image.dtype).view(1, 3, 1, 1)
        lower = (0.0 - mean) / std
        upper = (1.0 - mean) / std

        best_loss = float('inf')
        best_image = target_image.detach().clone()

        pbar = tqdm(total=num_steps, desc="风格迁移")
        try:
            for step in range(num_steps):
                optimizer.zero_grad()

                total_loss, content_loss, style_loss = self._compute_losses(
                    target_image, model, content_layers, style_layers
                )
                total_loss.backward()

                # The loss belongs to the image *before* this step's update,
                # so the snapshot is taken before the optimiser moves it.
                loss_value = total_loss.item()
                if loss_value < best_loss:
                    best_loss = loss_value
                    best_image = target_image.detach().clone()

                optimizer.step()
                scheduler.step()

                # Keep the image inside the valid pixel range.
                with torch.no_grad():
                    target_image.copy_(torch.clamp(target_image, min=lower, max=upper))

                # float() handles both tensors and the plain 0 returned when
                # a loss has no layers.
                pbar.set_postfix({
                    'Loss': f'{loss_value:.4f}',
                    'Content': f'{float(content_loss):.4f}',
                    'Style': f'{float(style_loss):.4f}'
                })
                pbar.update(1)

                if progress_callback and step % 10 == 0:
                    progress_callback(step, num_steps, loss_value)
        finally:
            pbar.close()

        # The image produced by the last update has not been scored yet.
        if num_steps > 0:
            with torch.no_grad():
                final_loss, _, _ = self._compute_losses(
                    target_image, model, content_layers, style_layers
                )
            if final_loss.item() < best_loss:
                best_image = target_image.detach().clone()

        final_image = best_image.squeeze(0).cpu()

        # Post-processing: undo the normalisation.
        final_image = final_image * std.view(3, 1, 1).cpu() + mean.view(3, 1, 1).cpu()
        final_image = torch.clamp(final_image, 0, 1)

        result_np = (final_image.numpy().transpose(1, 2, 0) * 255).astype(np.uint8)

        return result_np

    def _get_features(self, image, model, layers):
        """Run ``image`` through ``model`` and return the requested conv outputs.

        Layers are named ``conv<block>_<n>``: the block number increases after
        every max-pooling layer and ``n`` counts the convolutions inside the
        block. (Deriving both numbers from one running index is only right for
        blocks with exactly two convolutions.)

        Returns:
            Dict mapping each requested layer name found in ``model`` to the
            output captured by a forward hook on that layer.
        """
        features = {}
        hooks = []

        def make_hook(layer_name):
            def hook(module, inputs, output):
                features[layer_name] = output
            return hook

        block, index_in_block = 1, 0
        for layer in model.children():
            if isinstance(layer, nn.MaxPool2d):
                block += 1
                index_in_block = 0
            elif isinstance(layer, nn.Conv2d):
                index_in_block += 1
                layer_name = f"conv{block}_{index_in_block}"
                if layer_name in layers:
                    hooks.append(layer.register_forward_hook(make_hook(layer_name)))

        # Hooks are removed even if the forward pass raises.
        try:
            model(image)
        finally:
            for hook in hooks:
                hook.remove()

        return features
# Project root: tests/ -> project directory.
current_file = Path(__file__)
project_root = current_file.parent.parent

# Input image and the file written by the round-trip check.
img_path = project_root / "data" / "raw" / "content.jpg"
output_path = project_root / "data" / "processed" / "test_result.jpg"


def test_image_roundtrip():
    """Load the sample content image and write it back as a JPEG.

    Checks that the image can be read as a 3-channel array and that the file
    written to ``data/processed`` can be reopened with the same dimensions.
    The work lives in a function so that importing this module (for example
    during pytest collection) has no side effects.
    """
    print(f"📁 正在加载图片: {img_path}")

    # Convert to RGB so grayscale/RGBA inputs can still be saved as JPEG, and
    # close the file handle once the pixels are copied.
    with Image.open(img_path) as img:
        img_array = np.array(img.convert("RGB"))

    assert img_array.ndim == 3 and img_array.shape[2] == 3

    # Write the image back (checks that the output location is writable).
    output_path.parent.mkdir(parents=True, exist_ok=True)
    Image.fromarray(img_array).save(output_path)

    with Image.open(output_path) as saved:
        assert saved.size == (img_array.shape[1], img_array.shape[0])

    print("✅ 测试成功!")


if __name__ == "__main__":
    test_image_roundtrip()