commit
5aea9c39c0
@ -0,0 +1,3 @@
|
||||
# 默认忽略的文件
|
||||
/shelf/
|
||||
/workspace.xml
|
||||
@ -0,0 +1,6 @@
|
||||
<component name="InspectionProjectProfileManager">
|
||||
<settings>
|
||||
<option name="USE_PROJECT_PROFILE" value="false" />
|
||||
<version value="1.0" />
|
||||
</settings>
|
||||
</component>
|
||||
@ -0,0 +1,7 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="Black">
|
||||
<option name="sdkName" value="Python 3.11 (BP-introduce)" />
|
||||
</component>
|
||||
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.11 (BP-introduce)" project-jdk-type="Python SDK" />
|
||||
</project>
|
||||
@ -0,0 +1,8 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="ProjectModuleManager">
|
||||
<modules>
|
||||
<module fileurl="file://$PROJECT_DIR$/.idea/style_transfer_ai.iml" filepath="$PROJECT_DIR$/.idea/style_transfer_ai.iml" />
|
||||
</modules>
|
||||
</component>
|
||||
</project>
|
||||
@ -0,0 +1,8 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<module type="PYTHON_MODULE" version="4">
|
||||
<component name="NewModuleRootManager">
|
||||
<content url="file://$MODULE_DIR$" />
|
||||
<orderEntry type="jdk" jdkName="Python 3.11 (BP-introduce)" jdkType="Python SDK" />
|
||||
<orderEntry type="sourceFolder" forTests="false" />
|
||||
</component>
|
||||
</module>
|
||||
@ -0,0 +1,7 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="VcsDirectoryMappings">
|
||||
<mapping directory="$PROJECT_DIR$" vcs="Git" />
|
||||
<mapping directory="$PROJECT_DIR$/style_transfer_ai" vcs="Git" />
|
||||
</component>
|
||||
</project>
|
||||
@ -0,0 +1,24 @@
|
||||
# .pre-commit-config.yaml
|
||||
repos:
|
||||
- repo: https://github.com/psf/black
|
||||
rev: 23.3.0
|
||||
hooks:
|
||||
- id: black
|
||||
args: [--line-length=88]
|
||||
|
||||
- repo: https://github.com/pycqa/flake8
|
||||
rev: 5.0.4
|
||||
hooks:
|
||||
- id: flake8
|
||||
args: [--max-line-length=88]
|
||||
|
||||
- repo: https://github.com/pycqa/isort
|
||||
rev: 5.10.1
|
||||
hooks:
|
||||
- id: isort
|
||||
|
||||
- repo: https://github.com/python/mypy
|
||||
rev: v1.3.0
|
||||
hooks:
|
||||
- id: mypy
|
||||
args: [--ignore-missing-imports]
|
||||
@ -0,0 +1,36 @@
|
||||
# configs/stable_config.yaml
|
||||
# 稳定训练配置
|
||||
|
||||
# 模型配置
|
||||
model:
|
||||
model_type: "vgg16"
|
||||
image_size: 192
|
||||
content_layers: ["19"] # conv4_2
|
||||
style_layers: ["0", "5", "10", "17", "24"] # 全部5个层
|
||||
content_weight: 1e4 # 增加内容权重
|
||||
style_weight: 5e4 # 降低风格权重
|
||||
learning_rate: 0.005 # 低学习率稳定训练
|
||||
|
||||
# 路径配置
|
||||
paths:
|
||||
content_image: "data/raw/content.jpg"
|
||||
style_image: "data/raw/tem.jpg"
|
||||
output_dir: "data/processed"
|
||||
|
||||
# 训练配置
|
||||
training:
|
||||
num_steps: 150 # 减少步数
|
||||
log_interval: 15
|
||||
tv_weight: 1e-5
|
||||
|
||||
# 多尺度配置(先禁用,稳定后再启用)
|
||||
multi_scale:
|
||||
enabled: false
|
||||
scales: [192]
|
||||
scale_weights: [1.0]
|
||||
|
||||
# 性能配置
|
||||
performance:
|
||||
use_gpu: false
|
||||
batch_size: 1
|
||||
precision: "float32"
|
||||
|
After Width: | Height: | Size: 41 KiB |
|
After Width: | Height: | Size: 35 KiB |
|
After Width: | Height: | Size: 41 KiB |
|
After Width: | Height: | Size: 168 KiB |
|
After Width: | Height: | Size: 5.7 MiB |
|
After Width: | Height: | Size: 9.6 MiB |
@ -0,0 +1,20 @@
|
||||
[build-system]
|
||||
requires = ["setuptools>=45", "wheel"]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "style-transfer-ai"
|
||||
version = "0.1.0"
|
||||
description = "A style transfer AI tool using neural networks"
|
||||
authors = [
|
||||
{name = "Your Name", email = "your.email@example.com"},
|
||||
]
|
||||
dependencies = [
|
||||
"torch>=2.0.0",
|
||||
"torchvision>=0.15.0",
|
||||
"numpy>=1.24.0",
|
||||
"Pillow>=9.0.0",
|
||||
"scipy>=1.10.0",
|
||||
"PyYAML>=6.0.0",
|
||||
]
|
||||
|
||||
@ -0,0 +1,4 @@
|
||||
numpy>=1.24.0
|
||||
Pillow>=9.0.0
|
||||
scipy>=1.10.0
|
||||
PyYAML>=6.0.0
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -0,0 +1,82 @@
|
||||
"""
|
||||
多尺度风格迁移处理器。
|
||||
在不同尺度上处理图像以获得更好的结果。
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torchvision import transforms
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from typing import Dict, Any, List
|
||||
import cv2
|
||||
|
||||
|
||||
class MultiScaleProcessor:
|
||||
"""
|
||||
多尺度图像处理器。
|
||||
"""
|
||||
|
||||
def __init__(self, config: Dict[str, Any]):
|
||||
self.config = config
|
||||
self.scales = config.get('scales', [512, 256]) # 从大到小
|
||||
self.scale_weights = config.get('scale_weights', [1.0, 0.5])
|
||||
|
||||
def process_multi_scale(self,
|
||||
content_image: np.ndarray,
|
||||
style_image: np.ndarray,
|
||||
synthesizer: Any) -> np.ndarray:
|
||||
"""
|
||||
多尺度风格迁移。
|
||||
|
||||
Args:
|
||||
content_image: 内容图像
|
||||
style_image: 风格图像
|
||||
synthesizer: 图像合成器
|
||||
|
||||
Returns:
|
||||
合成后的图像
|
||||
"""
|
||||
current_result = None
|
||||
|
||||
for i, scale in enumerate(self.scales):
|
||||
print(f"处理尺度: {scale}x{scale}")
|
||||
|
||||
# 调整图像尺寸
|
||||
content_scaled = self._resize_image(content_image, scale)
|
||||
style_scaled = self._resize_image(style_image, scale)
|
||||
|
||||
# 如果是第一次迭代,使用内容图像作为初始值
|
||||
# 否则使用上一步的结果(调整到当前尺度)
|
||||
if current_result is None:
|
||||
initial_image = None
|
||||
else:
|
||||
initial_image = self._resize_image(current_result, scale)
|
||||
|
||||
# 在当前尺度上合成
|
||||
# 这里需要根据你的具体合成器接口调整
|
||||
current_result = synthesizer.synthesize(
|
||||
content_scaled,
|
||||
style_scaled,
|
||||
initial_image=initial_image,
|
||||
scale_weight=self.scale_weights[i]
|
||||
)
|
||||
|
||||
return current_result
|
||||
|
||||
def _resize_image(self, image: np.ndarray, size: int) -> np.ndarray:
|
||||
"""调整图像尺寸"""
|
||||
h, w = image.shape[:2]
|
||||
|
||||
if h > w:
|
||||
new_h = size
|
||||
new_w = int(w * size / h)
|
||||
else:
|
||||
new_w = size
|
||||
new_h = int(h * size / w)
|
||||
|
||||
# 使用高质量的重采样方法
|
||||
pil_image = Image.fromarray(image)
|
||||
resized = pil_image.resize((new_w, new_h), Image.LANCZOS)
|
||||
|
||||
return np.array(resized)
|
||||
Binary file not shown.
Binary file not shown.
Loading…
Reference in new issue