commit 5aea9c39c0894c0a0175e2c5d563d6bb44e8d617
Author: 张嘉欣 <1761133400@qq.com>
Date: Mon Dec 29 17:42:47 2025 +0800
项目初始化:基于深度学习的AI风格迁移系统
diff --git a/.idea/.gitignore b/.idea/.gitignore
new file mode 100644
index 0000000..359bb53
--- /dev/null
+++ b/.idea/.gitignore
@@ -0,0 +1,3 @@
+# 默认忽略的文件
+/shelf/
+/workspace.xml
diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml
new file mode 100644
index 0000000..105ce2d
--- /dev/null
+++ b/.idea/inspectionProfiles/profiles_settings.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/misc.xml b/.idea/misc.xml
new file mode 100644
index 0000000..c920e16
--- /dev/null
+++ b/.idea/misc.xml
@@ -0,0 +1,7 @@
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/modules.xml b/.idea/modules.xml
new file mode 100644
index 0000000..7d9e181
--- /dev/null
+++ b/.idea/modules.xml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/style_transfer_ai.iml b/.idea/style_transfer_ai.iml
new file mode 100644
index 0000000..e89a2b8
--- /dev/null
+++ b/.idea/style_transfer_ai.iml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/vcs.xml b/.idea/vcs.xml
new file mode 100644
index 0000000..77b61a5
--- /dev/null
+++ b/.idea/vcs.xml
@@ -0,0 +1,7 @@
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 0000000..146123b
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,24 @@
+# .pre-commit-config.yaml
+repos:
+ - repo: https://github.com/psf/black
+ rev: 23.3.0
+ hooks:
+ - id: black
+ args: [--line-length=88]
+
+ - repo: https://github.com/pycqa/flake8
+ rev: 5.0.4
+ hooks:
+ - id: flake8
+ args: [--max-line-length=88]
+
+ - repo: https://github.com/pycqa/isort
+ rev: 5.10.1
+ hooks:
+ - id: isort
+
+ - repo: https://github.com/python/mypy
+ rev: v1.3.0
+ hooks:
+ - id: mypy
+ args: [--ignore-missing-imports]
\ No newline at end of file
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..d294955
--- /dev/null
+++ b/README.md
@@ -0,0 +1,73 @@
+# 🎨 AI 画风迁移器 (StyleTransferAI)
+
+一个新颖、轻量级的Python项目,旨在通过神经风格迁移算法,实现快速有趣的艺术风格转换效果。
+
+## 🚀 快速开始
+
+1. **克隆项目**
+ ```bash
+ git clone https://github.com/yourname/style_transfer_ai.git
+ cd style_transfer_ai
+ ```
+
+2. **创建虚拟环境并安装依赖**
+ ```bash
+ python -m venv venv
+ source venv/bin/activate # Linux/Mac
+ # venv\Scripts\activate # Windows
+ pip install -r requirements.txt
+ ```
+
+3. **准备图像**
+ 在 `data/raw/` 目录下放置你的内容图 (`content.jpg`) 和风格图 (`style.jpg`)。
+
+4. **运行程序**
+ ```bash
+ python scripts/run_style_transfer.py
+ ```
+
+5. **查看结果**
+ 结果将保存在 `data/processed/extended_result_时间戳.jpg`。
+
+## 💡 项目特点
+
+- **高质量风格迁移**: 基于VGG16深度学习模型,生成高质量艺术效果。
+- **速度快**: CPU优化,无需GPU即可快速运行。
+- **可视化监控**: 实时显示训练损失,生成详细损失曲线图。
+- **高度可定制**: 修改 `configs/default.yaml` 或核心代码即可调整风格迁移效果。
+- **多模式训练**: 提供快速、扩展、高级三种训练模式。
+
+## 📂 项目结构
+
+```
+style_transfer_ai/
+│
+├── configs/ # 配置文件目录
+│ └── default.yaml # 默认超参数配置(学习率、迭代次数、权重等)
+│
+├── data/ # 数据目录
+│ ├── raw/ # 原始输入图像
+│ │ ├── content.jpg # 内容图像(用户自定义)
+│ │ └── style.jpg # 风格图像(用户自定义)
+│ └── processed/ # 输出结果目录(自动生成)
+│ └── extended_result_220811.jpg # 示例输出(带时间戳)
+│
+├── models/ # 模型相关代码
+│ └── vgg16_extractor.py # VGG16特征提取器封装
+│
+├── scripts/ # 可执行脚本
+│ └── run_style_transfer.py # 主程序入口
+│
+├── src/ # 核心源代码
+│ ├── style_transfer.py # 风格迁移核心算法实现
+│ ├── loss.py # 内容损失、风格损失计算
+│ ├── utils.py # 图像预处理、后处理、可视化工具
+│ └── trainer.py # 训练循环与优化器管理
+│
+├── outputs/ # 日志与可视化输出
+│ └── loss_curve_*.png # 损失变化曲线图(按运行时间命名)
+│
+├── requirements.txt # Python依赖列表
+├── README.md # 项目说明文档(即本文件)
+└── .gitignore # Git忽略规则
+```
\ No newline at end of file
diff --git a/configs/default.yaml b/configs/default.yaml
new file mode 100644
index 0000000..f25384e
--- /dev/null
+++ b/configs/default.yaml
@@ -0,0 +1,36 @@
+# configs/default.yaml
+# 稳定训练配置
+
+# 模型配置
+model:
+ model_type: "vgg16"
+ image_size: 192
+ content_layers: ["19"] # conv4_2
+ style_layers: ["0", "5", "10", "17", "24"] # 全部5个层
+ content_weight: 1e4 # 增加内容权重
+ style_weight: 5e4 # 降低风格权重
+ learning_rate: 0.005 # 低学习率稳定训练
+
+# 路径配置
+paths:
+ content_image: "data/raw/content.jpg"
+ style_image: "data/raw/tem.jpg"
+ output_dir: "data/processed"
+
+# 训练配置
+training:
+ num_steps: 150 # 减少步数
+ log_interval: 15
+ tv_weight: 1e-5
+
+# 多尺度配置(先禁用,稳定后再启用)
+multi_scale:
+ enabled: false
+ scales: [192]
+ scale_weights: [1.0]
+
+# 性能配置
+performance:
+ use_gpu: false
+ batch_size: 1
+ precision: "float32"
\ No newline at end of file
diff --git a/data/processed/extended_result_133911.jpg b/data/processed/extended_result_133911.jpg
new file mode 100644
index 0000000..3cf8281
Binary files /dev/null and b/data/processed/extended_result_133911.jpg differ
diff --git a/data/processed/extended_result_174121.jpg b/data/processed/extended_result_174121.jpg
new file mode 100644
index 0000000..4f086d2
Binary files /dev/null and b/data/processed/extended_result_174121.jpg differ
diff --git a/data/processed/extended_result_220811.jpg b/data/processed/extended_result_220811.jpg
new file mode 100644
index 0000000..5f400fe
Binary files /dev/null and b/data/processed/extended_result_220811.jpg differ
diff --git a/data/processed/tem.jpg b/data/processed/tem.jpg
new file mode 100644
index 0000000..7efc2c1
Binary files /dev/null and b/data/processed/tem.jpg differ
diff --git a/data/raw/content.jpg b/data/raw/content.jpg
new file mode 100644
index 0000000..d67a5a7
Binary files /dev/null and b/data/raw/content.jpg differ
diff --git a/data/raw/tem.jpg b/data/raw/tem.jpg
new file mode 100644
index 0000000..203618f
Binary files /dev/null and b/data/raw/tem.jpg differ
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000..ace7066
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,20 @@
+[build-system]
+requires = ["setuptools>=45", "wheel"]
+build-backend = "setuptools.build_meta"
+
+[project]
+name = "style-transfer-ai"
+version = "0.1.0"
+description = "A style transfer AI tool using neural networks"
+authors = [
+ {name = "Your Name", email = "your.email@example.com"},
+]
+dependencies = [
+ "torch>=2.0.0",
+ "torchvision>=0.15.0",
+ "numpy>=1.24.0",
+ "Pillow>=9.0.0",
+ "scipy>=1.10.0",
+ "PyYAML>=6.0.0",
+]
+
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..37b46a9
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,4 @@
+numpy>=1.24.0
+Pillow>=9.0.0
+scipy>=1.10.0
+PyYAML>=6.0.0
\ No newline at end of file
diff --git a/scripts/run_style_transfer.py b/scripts/run_style_transfer.py
new file mode 100644
index 0000000..d08e37d
--- /dev/null
+++ b/scripts/run_style_transfer.py
@@ -0,0 +1,263 @@
+# scripts/run_style_transfer.py
+import warnings
+
+warnings.filterwarnings("ignore", category=UserWarning, module="torchvision")
+
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from torchvision import models, transforms
+from PIL import Image
+from pathlib import Path
+import yaml
+import time
+import numpy as np
+
+
+print("=" * 60)
+
+
+def main():
+ # 加载配置
+ config_path = Path(__file__).parent.parent / "configs" / "default.yaml"
+ with open(config_path, 'r', encoding='utf-8') as f:
+ config = yaml.safe_load(f)
+
+ # ========== 关键修复:大幅提高风格影响 ==========
+ config['model']['image_size'] = 256
+ config['model']['content_weight'] = 1e3 # 降低内容权重
+ config['model']['style_weight'] = 1e7 # 大幅提高风格权重
+ config['model']['learning_rate'] = 0.003
+ config['training']['num_steps'] = 400
+ config['multi_scale']['enabled'] = False
+
+ print("📋 长时间训练参数:")
+ print(f" - 图像尺寸: 256x256")
+ print(f" - 内容权重: 1e3 (风格权重的1/10000)")
+ print(f" - 风格权重: 1e7 (内容权重的10000倍)")
+ print(f" - 权重比例: 1:10000 (强调风格)")
+ print(f" - 学习率: 0.003")
+ print(f" - 训练步数: 400 (约5分钟)")
+
+ # 设备设置
+ device = torch.device("cpu")
+ torch.set_num_threads(2)
+
+ # 加载图像
+ project_root = Path(__file__).parent.parent
+ content_path = project_root / config['paths']['content_image']
+ style_path = project_root / config['paths']['style_image']
+ output_dir = project_root / config['paths']['output_dir']
+
+ print(f"\n📷 加载图像...")
+ content_img = Image.open(content_path).convert('RGB')
+ style_img = Image.open(style_path).convert('RGB')
+
+ image_size = 256
+ content_img = content_img.resize((image_size, image_size), Image.Resampling.LANCZOS)
+ style_img = style_img.resize((image_size, image_size), Image.Resampling.LANCZOS)
+
+ print(f" 调整到: {image_size}x{image_size}")
+
+ # 预处理(移除标准化)
+ preprocess = transforms.Compose([
+ transforms.ToTensor(),
+ # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # 注释掉
+ ])
+
+ content_tensor = preprocess(content_img).unsqueeze(0).to(device)
+ style_tensor = preprocess(style_img).unsqueeze(0).to(device)
+ target_tensor = content_tensor.clone().requires_grad_(True)
+
+ # 加载VGG16模型
+ print("\n🔧 加载VGG16模型...")
+ start_time = time.time()
+ vgg = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1).features.to(device).eval()
+
+ for param in vgg.parameters():
+ param.requires_grad_(False)
+
+ print(f"✅ 模型加载完成 ({time.time() - start_time:.2f}秒)")
+
+ # 定义层索引
+ content_layers = ["21"] # conv4_3 in VGG16 features (conv4_2 is index 19)
+ style_layers = ["0", "5", "10", "19", "28"] # conv1_1, conv2_1, conv3_1, conv4_2, conv5_3 (VGG16 indices)
+
+ # 特征提取函数
+ def get_features_fixed(x, layers):
+ features = {}
+ x_tmp = x
+ layers_int = set([int(l) for l in layers])
+
+ for idx, (name, layer) in enumerate(vgg._modules.items()):
+ x_tmp = layer(x_tmp)
+ if idx in layers_int:
+ features[str(idx)] = x_tmp
+
+ return features
+
+ # 提取特征
+ print("\n📊 提取特征...")
+ start_time = time.time()
+
+ # 标准化输入以匹配VGG的预训练权重
+ def normalize_for_vgg(tensor):
+ mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device)
+ std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device)
+ return (tensor - mean) / std
+
+ content_features = get_features_fixed(normalize_for_vgg(content_tensor), content_layers)
+ style_features = get_features_fixed(normalize_for_vgg(style_tensor), style_layers)
+
+ # 修改Gram矩阵计算(关键修复)
+ def gram_matrix_fixed(tensor):
+ b, c, h, w = tensor.size()
+
+ # 展平特征
+ tensor_flat = tensor.view(b, c, -1) # [b, c, h*w]
+
+ # 计算Gram矩阵
+ gram = torch.bmm(tensor_flat, tensor_flat.transpose(1, 2)) # [b, c, c]
+
+ # 关键修改:只除以h*w,而不是c*h*w(原始论文做法)
+ gram = gram / (h * w)
+
+ return gram
+
+ # 预计算风格Gram矩阵
+ style_grams = {}
+ for layer, feat in style_features.items():
+ style_grams[layer] = gram_matrix_fixed(feat)
+ print(f" 层 {layer} Gram矩阵形状: {style_grams[layer].shape}")
+
+ print(f"✅ 特征提取完成 ({time.time() - start_time:.2f}秒)")
+
+ # 优化器设置
+ learning_rate = 0.003 # 使用更高的学习率
+ optimizer = optim.Adam([target_tensor], lr=learning_rate)
+ scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.5)
+
+ # 训练参数
+ content_weight = 1e3
+ style_weight = 1e7 # 大幅提高
+ num_steps = 400
+ log_interval = 40
+
+ print(f"\n🎨 开始长时间训练 (约5分钟)...")
+ print(f" 优化器: Adam (初始lr={learning_rate})")
+ print(f" 内容权重: {content_weight}")
+ print(f" 风格权重: {style_weight} (非常高,强调风格转移)")
+ print(f" 权重比例: 1:{int(style_weight / content_weight)}")
+
+ start_time = time.time()
+
+ # 记录损失历史
+ loss_history = []
+ content_loss_history = []
+ style_loss_history = []
+
+ for step in range(1, num_steps + 1):
+ optimizer.zero_grad()
+
+ # 提取目标特征(需要标准化)
+ target_normalized = normalize_for_vgg(target_tensor)
+ target_features = get_features_fixed(target_normalized, content_layers + style_layers)
+
+ # 计算内容损失
+ content_loss = torch.tensor(0.0, device=device)
+ for layer in content_layers:
+ if layer in target_features and layer in content_features:
+ diff = target_features[layer] - content_features[layer]
+ content_loss += torch.mean(diff ** 2)
+
+ # 计算风格损失
+ style_loss = torch.tensor(0.0, device=device)
+ for layer in style_layers:
+ if layer in target_features and layer in style_grams:
+ target_gram = gram_matrix_fixed(target_features[layer])
+ style_gram = style_grams[layer]
+ diff = target_gram - style_gram
+ style_loss += torch.mean(diff ** 2)
+
+ # 总损失
+ total_loss = content_weight * content_loss + style_weight * style_loss
+
+ # 反向传播
+ total_loss.backward()
+
+ # 梯度裁剪(提高限制)
+ torch.nn.utils.clip_grad_norm_([target_tensor], max_norm=10.0)
+
+ optimizer.step()
+
+ # 更新学习率
+ if step % 100 == 0:
+ scheduler.step()
+ current_lr = optimizer.param_groups[0]['lr']
+ print(f"📉 第 {step} 步: 学习率降低到 {current_lr:.6f}")
+
+ # 限制像素值
+ with torch.no_grad():
+ target_tensor.data.clamp_(0, 1)
+
+ # 记录损失
+ loss_history.append(total_loss.item())
+ content_loss_history.append(content_loss.item())
+ style_loss_history.append(style_loss.item())
+
+ # 定期输出
+ if step % log_interval == 0 or step == num_steps:
+ elapsed = time.time() - start_time
+ steps_per_second = step / elapsed if elapsed > 0 else 0
+ remaining_steps = num_steps - step
+ estimated_remaining = remaining_steps / steps_per_second if steps_per_second > 0 else 0
+
+ # 计算风格损失贡献
+ style_contribution = style_weight * style_loss.item()
+ content_contribution = content_weight * content_loss.item()
+ style_ratio = style_contribution / (style_contribution + content_contribution) * 100
+
+ print(f"⏳ 第 {step:3d}/{num_steps} | "
+ f"总损失: {total_loss.item():8.2f} | "
+ f"内容: {content_loss.item():6.4f} | "
+ f"风格: {style_loss.item():6.4f} | "
+ f"风格贡献: {style_ratio:.1f}% | "
+ f"用时: {elapsed:.1f}s | "
+ f"剩余: {estimated_remaining:.1f}s")
+
+ total_time = time.time() - start_time
+
+ # 分析损失变化
+ print(f"\n📈 损失分析:")
+ print(f" 总损失变化: {loss_history[0]:.2f} → {loss_history[-1]:.2f}")
+ print(f" 内容损失变化: {content_loss_history[0]:.4f} → {content_loss_history[-1]:.4f}")
+ print(f" 风格损失变化: {style_loss_history[0]:.4f} → {style_loss_history[-1]:.4f}")
+
+ # 保存结果
+ postprocess = transforms.Compose([
+ # transforms.Normalize(mean=[-0.485/0.229, -0.456/0.224, -0.406/0.225],
+ # std=[1/0.229, 1/0.224, 1/0.225]), # 如果之前没标准化,这里也不需要
+ transforms.ToPILImage()
+ ])
+
+ output_dir.mkdir(parents=True, exist_ok=True)
+ with torch.no_grad():
+ result_img = postprocess(target_tensor.squeeze(0).cpu())
+
+ timestamp = time.strftime("%H%M%S")
+ output_path = output_dir / f"extended_result_{timestamp}.jpg"
+ result_img.save(output_path, quality=95, optimize=True)
+
+ print(f"\n🎉 训练完成!")
+ print(f" 总用时: {total_time:.1f}秒")
+ print(f" 结果保存至: {output_path}")
+
+ # 尝试显示结果
+ try:
+ result_img.show(title="风格迁移结果")
+ except:
+ print(" ℹ️ 请手动打开图像查看")
+
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/src/style_transfer_ai1/core/__init__.py b/src/style_transfer_ai1/core/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/style_transfer_ai1/core/__pycache__/__init__.cpython-311.pyc b/src/style_transfer_ai1/core/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000..970f3df
Binary files /dev/null and b/src/style_transfer_ai1/core/__pycache__/__init__.cpython-311.pyc differ
diff --git a/src/style_transfer_ai1/core/__pycache__/feature_extractor.cpython-311.pyc b/src/style_transfer_ai1/core/__pycache__/feature_extractor.cpython-311.pyc
new file mode 100644
index 0000000..8a6c53f
Binary files /dev/null and b/src/style_transfer_ai1/core/__pycache__/feature_extractor.cpython-311.pyc differ
diff --git a/src/style_transfer_ai1/core/__pycache__/image_synthesizer.cpython-311.pyc b/src/style_transfer_ai1/core/__pycache__/image_synthesizer.cpython-311.pyc
new file mode 100644
index 0000000..303d064
Binary files /dev/null and b/src/style_transfer_ai1/core/__pycache__/image_synthesizer.cpython-311.pyc differ
diff --git a/src/style_transfer_ai1/core/feature_extractor.py b/src/style_transfer_ai1/core/feature_extractor.py
new file mode 100644
index 0000000..753a3d7
--- /dev/null
+++ b/src/style_transfer_ai1/core/feature_extractor.py
@@ -0,0 +1,103 @@
+"""
+风格特征提取器模块 - 基础版本和传统方法。
+"""
+
+import numpy as np
+from typing import Tuple, Dict, Any
+from abc import ABC, abstractmethod
+
+
+class BaseFeatureExtractor(ABC):
+ """特征提取器基类"""
+
+ @abstractmethod
+ def extract(self, image: np.ndarray) -> Dict[str, Any]:
+ """提取特征抽象方法"""
+ pass
+
+ @abstractmethod
+ def get_features(self) -> Dict[str, Any]:
+ """获取特征抽象方法"""
+ pass
+
+
+class StyleFeatureExtractor(BaseFeatureExtractor):
+ """
+ 传统风格特征提取器。
+ 通过计算图像的直方图和局部纹理特征来定义"风格"。
+ """
+
+ def __init__(self, config: Dict[str, Any]):
+ """
+ 初始化提取器。
+
+ Args:
+ config: 配置字典,包含特征提取参数。
+ """
+ self.config = config
+ self.style_features: Dict[str, Any] = {}
+
+ def extract(self, image: np.ndarray) -> Dict[str, Any]:
+ """
+ 从图像中提取风格特征。
+
+ Args:
+ image: 输入图像,形状为 (H, W, C),值域 [0, 255]。
+
+ Returns:
+ 一个包含各种风格特征的字典。
+ """
+ # 确保图像在[0, 1]范围
+ image_normalized = image.astype(np.float32) / 255.0
+
+ features = {}
+
+ # 1. 颜色直方图 (RGB通道)
+ hist_r, _ = np.histogram(image_normalized[:, :, 0].ravel(), bins=32, range=(0, 1))
+ hist_g, _ = np.histogram(image_normalized[:, :, 1].ravel(), bins=32, range=(0, 1))
+ hist_b, _ = np.histogram(image_normalized[:, :, 2].ravel(), bins=32, range=(0, 1))
+ features['color_histogram'] = np.concatenate([hist_r, hist_g, hist_b])
+
+ # 2. 局部纹理特征
+ diff_h = np.abs(np.diff(image_normalized, axis=1)).mean()
+ diff_v = np.abs(np.diff(image_normalized, axis=0)).mean()
+ features['texture_strength'] = np.array([diff_h, diff_v])
+
+ # 3. 平均颜色
+ features['mean_color'] = image_normalized.mean(axis=(0, 1))
+
+ # 4. 颜色标准差
+ features['color_std'] = image_normalized.std(axis=(0, 1))
+
+ self.style_features = features
+ return features
+
+ def get_features(self) -> Dict[str, Any]:
+ """
+ 获取最近一次提取的风格特征。
+
+ Returns:
+ 风格特征字典。
+ """
+ return self.style_features.copy()
+
+
+# 工厂函数,便于创建不同类型的提取器
+def create_feature_extractor(extractor_type: str, config: Dict[str, Any]) -> BaseFeatureExtractor:
+ """
+ 创建特征提取器工厂函数。
+
+ Args:
+ extractor_type: 提取器类型 ('traditional' 或 'neural')
+ config: 配置字典
+
+ Returns:
+ 特征提取器实例
+ """
+ if extractor_type == 'traditional':
+ return StyleFeatureExtractor(config)
+ elif extractor_type == 'neural':
+ from .neural_feature_extractor import NeuralStyleFeatureExtractor
+ return NeuralStyleFeatureExtractor(config)
+ else:
+ raise ValueError(f"不支持的提取器类型: {extractor_type}")
\ No newline at end of file
diff --git a/src/style_transfer_ai1/core/image_synthesizer.py b/src/style_transfer_ai1/core/image_synthesizer.py
new file mode 100644
index 0000000..8501906
--- /dev/null
+++ b/src/style_transfer_ai1/core/image_synthesizer.py
@@ -0,0 +1,171 @@
+"""
+图像合成器模块 - 基础版本和传统方法。
+"""
+
+import numpy as np
+from typing import Tuple, Dict, Any
+from scipy.optimize import minimize
+import warnings
+from abc import ABC, abstractmethod
+
+warnings.filterwarnings("ignore", category=UserWarning)
+
+
+class BaseImageSynthesizer(ABC):
+ """图像合成器基类"""
+
+ @abstractmethod
+ def set_content(self, image: np.ndarray):
+ """设置内容图像"""
+ pass
+
+ @abstractmethod
+ def set_style_features(self, features: Dict[str, Any]):
+ """设置风格特征"""
+ pass
+
+ @abstractmethod
+ def synthesize(self, initial_image: np.ndarray = None, **kwargs) -> np.ndarray:
+ """合成图像"""
+ pass
+
+
+class ImageSynthesizer(BaseImageSynthesizer):
+ """
+ 传统图像合成器,负责生成最终的风格化图像。
+ """
+
+ def __init__(self, config: Dict[str, Any]):
+ """
+ 初始化合成器。
+
+ Args:
+ config: 配置字典。
+ """
+ self.config = config
+ self.content_image: np.ndarray = None
+ self.style_features: Dict[str, Any] = {}
+ self.synthesized_image: np.ndarray = None
+
+ def set_content(self, image: np.ndarray):
+ """
+ 设置内容图像。
+
+ Args:
+ image: 内容图像,形状为 (H, W, C)。
+ """
+ self.content_image = image.astype(np.float32) / 255.0
+
+ def set_style_features(self, features: Dict[str, Any]):
+ """
+ 设置风格特征。
+
+ Args:
+ features: 由 StyleFeatureExtractor 提取的特征字典。
+ """
+ self.style_features = features
+
+ def synthesize(self, initial_image: np.ndarray = None, max_iter: int = 200) -> np.ndarray:
+ """
+ 合成风格化图像。
+
+ Args:
+ initial_image: 初始图像,如果为None,则使用内容图像。
+ max_iter: 最大迭代次数。
+
+ Returns:
+ 合成后的风格化图像,形状为 (H, W, C),值域 [0, 255]。
+ """
+ if initial_image is None:
+ initial_image = self.content_image.copy()
+
+ # 将图像展平为一维向量以便于优化
+ h, w, c = self.content_image.shape
+ x0 = initial_image.ravel()
+
+ # 定义损失函数
+ def loss_function(x):
+ # 重塑回图像形状
+ img = x.reshape((h, w, c))
+
+ # 1. 内容损失:保持与原图相似
+ content_loss = np.mean((img - self.content_image) ** 2)
+
+ # 2. 风格损失:匹配风格特征
+ style_loss = self._calculate_style_loss(img)
+
+ # 总损失
+ total_loss = content_loss + self.config.get('style_weight', 1.0) * style_loss
+
+ return total_loss
+
+ # 执行优化
+ result = minimize(loss_function, x0, method='L-BFGS-B',
+ options={'maxiter': max_iter, 'disp': False})
+
+ # 重塑结果并转换回[0, 255]
+ synthesized = result.x.reshape((h, w, c))
+ synthesized = np.clip(synthesized, 0, 1) * 255
+ self.synthesized_image = synthesized.astype(np.uint8)
+
+ return self.synthesized_image
+
+ def _calculate_style_loss(self, image: np.ndarray) -> float:
+ """
+ 计算当前图像与目标风格特征之间的损失。
+
+ Args:
+ image: 当前合成的图像。
+
+ Returns:
+ 风格损失值。
+ """
+ loss = 0.0
+
+ # 1. 颜色直方图损失
+ hist_r, _ = np.histogram(image[:, :, 0].ravel(), bins=32, range=(0, 1))
+ hist_g, _ = np.histogram(image[:, :, 1].ravel(), bins=32, range=(0, 1))
+ hist_b, _ = np.histogram(image[:, :, 2].ravel(), bins=32, range=(0, 1))
+ current_hist = np.concatenate([hist_r, hist_g, hist_b])
+ target_hist = self.style_features['color_histogram']
+ # 使用KL散度作为距离度量
+ epsilon = 1e-10
+ kl_div = np.sum(target_hist * np.log((target_hist + epsilon) / (current_hist + epsilon)))
+ loss += kl_div
+
+ # 2. 纹理强度损失
+ diff_h = np.abs(np.diff(image, axis=1)).mean()
+ diff_v = np.abs(np.diff(image, axis=0)).mean()
+ current_texture = np.array([diff_h, diff_v])
+ target_texture = self.style_features['texture_strength']
+ texture_loss = np.mean((current_texture - target_texture) ** 2)
+ loss += texture_loss
+
+ # 3. 平均颜色损失
+ current_mean = image.mean(axis=(0, 1))
+ target_mean = self.style_features['mean_color']
+ mean_loss = np.mean((current_mean - target_mean) ** 2)
+ loss += mean_loss
+
+ return loss
+
+
+# 工厂函数
+def create_image_synthesizer(synthesizer_type: str, config: Dict[str, Any]) -> BaseImageSynthesizer:
+ """
+ 创建图像合成器工厂函数。
+
+ Args:
+ synthesizer_type: 合成器类型 ('traditional' 或 'neural')
+ config: 配置字典
+
+ Returns:
+ 图像合成器实例
+ """
+ if synthesizer_type == 'traditional':
+ return ImageSynthesizer(config)
+ elif synthesizer_type == 'neural':
+ from .neural_image_synthesizer import NeuralImageSynthesizer
+ return NeuralImageSynthesizer(config)
+ else:
+ raise ValueError(f"不支持的合成器类型: {synthesizer_type}")
\ No newline at end of file
diff --git a/src/style_transfer_ai1/core/multi_scale_processor.py b/src/style_transfer_ai1/core/multi_scale_processor.py
new file mode 100644
index 0000000..303a8b2
--- /dev/null
+++ b/src/style_transfer_ai1/core/multi_scale_processor.py
@@ -0,0 +1,82 @@
+"""
+多尺度风格迁移处理器。
+在不同尺度上处理图像以获得更好的结果。
+"""
+
+import torch
+import torch.nn as nn
+from torchvision import transforms
+import numpy as np
+from PIL import Image
+from typing import Dict, Any, List
+import cv2
+
+
+class MultiScaleProcessor:
+ """
+ 多尺度图像处理器。
+ """
+
+ def __init__(self, config: Dict[str, Any]):
+ self.config = config
+ self.scales = config.get('scales', [512, 256]) # 从大到小
+ self.scale_weights = config.get('scale_weights', [1.0, 0.5])
+
+ def process_multi_scale(self,
+ content_image: np.ndarray,
+ style_image: np.ndarray,
+ synthesizer: Any) -> np.ndarray:
+ """
+ 多尺度风格迁移。
+
+ Args:
+ content_image: 内容图像
+ style_image: 风格图像
+ synthesizer: 图像合成器
+
+ Returns:
+ 合成后的图像
+ """
+ current_result = None
+
+ for i, scale in enumerate(self.scales):
+ print(f"处理尺度: {scale}x{scale}")
+
+ # 调整图像尺寸
+ content_scaled = self._resize_image(content_image, scale)
+ style_scaled = self._resize_image(style_image, scale)
+
+ # 如果是第一次迭代,使用内容图像作为初始值
+ # 否则使用上一步的结果(调整到当前尺度)
+ if current_result is None:
+ initial_image = None
+ else:
+ initial_image = self._resize_image(current_result, scale)
+
+ # 在当前尺度上合成
+ # 这里需要根据你的具体合成器接口调整
+ current_result = synthesizer.synthesize(
+ content_scaled,
+ style_scaled,
+ initial_image=initial_image,
+ scale_weight=self.scale_weights[i]
+ )
+
+ return current_result
+
+ def _resize_image(self, image: np.ndarray, size: int) -> np.ndarray:
+ """调整图像尺寸"""
+ h, w = image.shape[:2]
+
+ if h > w:
+ new_h = size
+ new_w = int(w * size / h)
+ else:
+ new_w = size
+ new_h = int(h * size / w)
+
+ # 使用高质量的重采样方法
+ pil_image = Image.fromarray(image)
+ resized = pil_image.resize((new_w, new_h), Image.LANCZOS)
+
+ return np.array(resized)
\ No newline at end of file
diff --git a/src/style_transfer_ai1/core/neural_feature_extractor.py b/src/style_transfer_ai1/core/neural_feature_extractor.py
new file mode 100644
index 0000000..01db088
--- /dev/null
+++ b/src/style_transfer_ai1/core/neural_feature_extractor.py
@@ -0,0 +1,126 @@
+"""
+基于深度学习的风格特征提取器。
+使用预训练的CNN网络提取高级风格特征。
+"""
+
+import torch
+import torch.nn as nn
+import torchvision.models as models
+import torchvision.transforms as transforms
+from PIL import Image
+import numpy as np
+from typing import Dict, List, Any
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class NeuralStyleFeatureExtractor:
+ """
+ 使用VGG网络提取风格特征的提取器。
+ """
+
+ def __init__(self, config: Dict[str, Any]):
+ self.config = config
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ self.model = self._load_model()
+ self.preprocess = self._get_preprocess()
+
+ # 定义用于风格提取的层
+ self.style_layers = {
+ 'vgg19': ['conv1_1', 'conv2_1', 'conv3_1', 'conv4_1', 'conv5_1'],
+ 'vgg16': ['conv1_2', 'conv2_2', 'conv3_3', 'conv4_3']
+ }[config.get('model_type', 'vgg19')]
+
+ self.features = {}
+
+ def _load_model(self):
+ """加载预训练的VGG模型"""
+ model_type = self.config.get('model_type', 'vgg19')
+ if model_type == 'vgg19':
+ model = models.vgg19(pretrained=True).features
+ else:
+ model = models.vgg16(pretrained=True).features
+
+ # 冻结参数
+ for param in model.parameters():
+ param.requires_grad = False
+
+ return model.to(self.device).eval()
+
+ def _get_preprocess(self):
+ """获取图像预处理流程"""
+ return transforms.Compose([
+ transforms.Resize(self.config.get('image_size', 512)),
+ transforms.ToTensor(),
+ transforms.Normalize(
+ mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225]
+ )
+ ])
+
+ def _hook_layer(self, layer_name):
+ """注册钩子来获取中间层输出"""
+
+ def hook(module, input, output):
+ self.features[layer_name] = output
+
+ return hook
+
+ def _register_hooks(self):
+ """为风格层注册钩子"""
+ self.hooks = []
+ layer_idx = 0
+
+ for name, layer in self.model.named_children():
+ if isinstance(layer, nn.Conv2d):
+ layer_name = f"conv{layer_idx // 2 + 1}_{(layer_idx % 2) + 1}"
+ if layer_name in self.style_layers:
+ hook = layer.register_forward_hook(self._hook_layer(layer_name))
+ self.hooks.append(hook)
+ layer_idx += 1
+
+ def _remove_hooks(self):
+ """移除所有钩子"""
+ for hook in self.hooks:
+ hook.remove()
+
+ def gram_matrix(self, tensor):
+ """计算Gram矩阵"""
+ batch_size, channels, height, width = tensor.size()
+ features = tensor.view(batch_size, channels, height * width)
+ gram = torch.bmm(features, features.transpose(1, 2))
+ return gram / (channels * height * width)
+
+ def extract(self, image: np.ndarray) -> Dict[str, Any]:
+ """
+ 从图像中提取风格特征。
+
+ Args:
+ image: 输入图像,形状为 (H, W, C),值域 [0, 255]
+
+ Returns:
+ 包含Gram矩阵等高级风格特征的字典
+ """
+ # 转换numpy数组为PIL图像
+ if isinstance(image, np.ndarray):
+ image = Image.fromarray(image.astype('uint8'))
+
+ # 预处理
+ input_tensor = self.preprocess(image).unsqueeze(0).to(self.device)
+
+ # 注册钩子并前向传播
+ self._register_hooks()
+ self.model(input_tensor)
+ self._remove_hooks()
+
+ # 计算Gram矩阵
+ style_features = {}
+ for layer_name in self.style_layers:
+ if layer_name in self.features:
+ style_features[layer_name] = self.gram_matrix(self.features[layer_name])
+
+ # 清理
+ self.features.clear()
+
+ return style_features
\ No newline at end of file
diff --git a/src/style_transfer_ai1/core/neural_image_synthesizer.py b/src/style_transfer_ai1/core/neural_image_synthesizer.py
new file mode 100644
index 0000000..6ed3e62
--- /dev/null
+++ b/src/style_transfer_ai1/core/neural_image_synthesizer.py
@@ -0,0 +1,207 @@
+"""
+基于PyTorch的高效图像合成器。
+使用GPU加速和现代优化技术。
+"""
+
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from torchvision import transforms
+from PIL import Image
+import numpy as np
+from typing import Dict, Any, List, Optional
+import time
+from tqdm import tqdm
+
+
+class NeuralImageSynthesizer:
+ """
+ 使用神经网络特征进行高效图像合成的合成器。
+ """
+
+ def __init__(self, config: Dict[str, Any]):
+ self.config = config
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+ # 优化参数
+ self.content_weight = config.get('content_weight', 1e0)
+ self.style_weight = config.get('style_weight', 1e6)
+ self.tv_weight = config.get('tv_weight', 1e-6) # 全变分正则化
+
+ self.model = None
+ self.content_features = None
+ self.style_features = None
+
+ def set_content_features(self, content_tensor: torch.Tensor, content_features: Dict[str, torch.Tensor]):
+ """设置内容特征"""
+ self.content_tensor = content_tensor
+ self.content_features = content_features
+
+ def set_style_features(self, style_features: Dict[str, torch.Tensor]):
+ """设置风格特征"""
+ self.style_features = style_features
+
+ def _calculate_content_loss(self, target_features, content_features):
+ """计算内容损失"""
+ content_loss = 0
+ for layer in content_features:
+ target_feat = target_features[layer]
+ content_feat = content_features[layer]
+ content_loss += torch.nn.functional.mse_loss(target_feat, content_feat)
+ return content_loss
+
+ def _calculate_style_loss(self, target_features, style_features):
+ """计算风格损失"""
+ style_loss = 0
+ for layer in style_features:
+ target_feat = target_features[layer]
+ target_gram = self._gram_matrix(target_feat)
+ style_gram = style_features[layer]
+ style_loss += torch.nn.functional.mse_loss(target_gram, style_gram)
+ return style_loss
+
+ def _gram_matrix(self, tensor):
+ """计算Gram矩阵"""
+ batch_size, channels, height, width = tensor.size()
+ features = tensor.view(batch_size, channels, height * width)
+ gram = torch.bmm(features, features.transpose(1, 2))
+ return gram / (channels * height * width)
+
+ def _total_variation_loss(self, image):
+ """全变分正则化,减少噪声"""
+ tv_h = torch.mean(torch.abs(image[:, :, 1:, :] - image[:, :, :-1, :]))
+ tv_w = torch.mean(torch.abs(image[:, :, :, 1:] - image[:, :, :, :-1]))
+ return tv_h + tv_w
+
+ def synthesize(self,
+ model: nn.Module,
+ content_layers: List[str],
+ style_layers: List[str],
+ num_steps: int = 1000,
+ learning_rate: float = 0.01,
+ progress_callback: Optional[callable] = None) -> np.ndarray:
+ """
+ 合成风格化图像。
+
+ Args:
+ model: 特征提取模型
+ content_layers: 内容层列表
+ style_layers: 风格层列表
+ num_steps: 迭代次数
+ learning_rate: 学习率
+ progress_callback: 进度回调函数
+
+ Returns:
+ 合成图像,形状为 (H, W, C),值域 [0, 255]
+ """
+ self.model = model
+
+ # 初始化目标图像(从内容图像开始)
+ target_image = self.content_tensor.clone().requires_grad_(True)
+
+ # 使用Adam优化器(比L-BFGS-B更适合这个问题)
+ optimizer = optim.Adam([target_image], lr=learning_rate)
+
+ # 学习率调度器
+ scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=num_steps // 3, gamma=0.5)
+
+ # 进度条
+ pbar = tqdm(total=num_steps, desc="风格迁移")
+
+ best_loss = float('inf')
+ best_image = None
+
+ for step in range(num_steps):
+ optimizer.zero_grad()
+
+ # 获取目标图像特征
+ target_features = self._get_features(target_image, model, content_layers + style_layers)
+
+ # 计算各项损失
+ content_loss = self._calculate_content_loss(
+ {k: v for k, v in target_features.items() if k in content_layers},
+ self.content_features
+ )
+
+ style_loss = self._calculate_style_loss(
+ {k: v for k, v in target_features.items() if k in style_layers},
+ self.style_features
+ )
+
+ tv_loss = self._total_variation_loss(target_image)
+
+ # 总损失
+ total_loss = (self.content_weight * content_loss +
+ self.style_weight * style_loss +
+ self.tv_weight * tv_loss)
+
+ # 反向传播
+ total_loss.backward()
+ optimizer.step()
+ scheduler.step()
+
+ # 限制像素值范围
+ with torch.no_grad():
+ target_image.data.clamp_(0, 1)
+
+ # 保存最佳结果
+ if total_loss.item() < best_loss:
+ best_loss = total_loss.item()
+ best_image = target_image.detach().clone()
+
+ # 更新进度
+ pbar.set_postfix({
+ 'Loss': f'{total_loss.item():.4f}',
+ 'Content': f'{content_loss.item():.4f}',
+ 'Style': f'{style_loss.item():.4f}'
+ })
+ pbar.update(1)
+
+ if progress_callback and step % 10 == 0:
+ progress_callback(step, num_steps, total_loss.item())
+
+ pbar.close()
+
+ # 使用最佳结果
+ final_image = best_image.squeeze(0).cpu()
+
+ # 后处理:反标准化
+ mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
+ std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
+ final_image = final_image * std + mean
+ final_image = torch.clamp(final_image, 0, 1)
+
+ # 转换为numpy数组
+ result_np = (final_image.numpy().transpose(1, 2, 0) * 255).astype(np.uint8)
+
+ return result_np
+
+ def _get_features(self, image, model, layers):
+ """获取指定层的特征"""
+ features = {}
+
+ def hook_fn(module, input, output, layer_name):
+ features[layer_name] = output
+
+ hooks = []
+ layer_idx = 0
+
+ # 注册钩子
+ for name, layer in model.named_children():
+ if isinstance(layer, nn.Conv2d):
+ layer_name = f"conv{layer_idx // 2 + 1}_{(layer_idx % 2) + 1}"
+ if layer_name in layers:
+ hook = layer.register_forward_hook(
+ lambda m, i, o, ln=layer_name: hook_fn(m, i, o, ln)
+ )
+ hooks.append(hook)
+ layer_idx += 1
+
+ # 前向传播
+ model(image)
+
+ # 移除钩子
+ for hook in hooks:
+ hook.remove()
+
+ return features
\ No newline at end of file
diff --git a/src/style_transfer_ai1/utils/__init__.py b/src/style_transfer_ai1/utils/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/style_transfer_ai1/utils/__pycache__/__init__.cpython-311.pyc b/src/style_transfer_ai1/utils/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000..be66553
Binary files /dev/null and b/src/style_transfer_ai1/utils/__pycache__/__init__.cpython-311.pyc differ
diff --git a/src/style_transfer_ai1/utils/__pycache__/image_utils.cpython-311.pyc b/src/style_transfer_ai1/utils/__pycache__/image_utils.cpython-311.pyc
new file mode 100644
index 0000000..af3fa68
Binary files /dev/null and b/src/style_transfer_ai1/utils/__pycache__/image_utils.cpython-311.pyc differ
diff --git a/src/style_transfer_ai1/utils/image_utils.py b/src/style_transfer_ai1/utils/image_utils.py
new file mode 100644
index 0000000..e277e05
--- /dev/null
+++ b/src/style_transfer_ai1/utils/image_utils.py
@@ -0,0 +1,38 @@
+"""
+图像处理工具模块。
+
+提供加载和保存图像的便捷函数。
+"""
+
+import numpy as np
+from PIL import Image
+from typing import Tuple
+
+
+def load_image(path: str) -> np.ndarray:
+ """
+ 加载图像。
+
+ Args:
+ path: 图像文件路径。
+
+ Returns:
+ 图像数据,形状为 (H, W, C)。
+ """
+ with Image.open(path) as img:
+ # 转换为RGB,避免灰度图或RGBA图导致问题
+ if img.mode != 'RGB':
+ img = img.convert('RGB')
+ return np.array(img)
+
+
+def save_image(image: np.ndarray, path: str):
+ """
+ 保存图像。
+
+ Args:
+ image: 图像数据,形状为 (H, W, C)。
+ path: 保存路径。
+ """
+ img = Image.fromarray(image)
+ img.save(path)
\ No newline at end of file
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/test_core.py b/tests/test_core.py
new file mode 100644
index 0000000..49524cb
--- /dev/null
+++ b/tests/test_core.py
@@ -0,0 +1,22 @@
+from PIL import Image
+import numpy as np
+from pathlib import Path
+
+# 获取项目根目录
+current_file = Path(__file__) # D:\1\1\style_transfer_ai1\tests\test_core.py
+project_root = current_file.parent.parent # 上两级:tests -> style_transfer_ai1
+
+# 图片路径
+img_path = project_root / "data" / "raw" / "content.jpg"
+
+print(f"📁 正在加载图片: {img_path}")
+
+# 加载图像
+img = Image.open(img_path)
+img_array = np.array(img)
+
+# 保存原图(测试是否能写入)
+output_path = project_root / "data" / "processed" / "test_result.jpg"
+Image.fromarray(img_array).save(output_path)
+
+print("✅ 测试成功!")
\ No newline at end of file