|
|
|
|
|
|
|
|
|
Python项目代码规范
|
|
|
|
|
代码风格规范
|
|
|
|
|
(1)格式规范
|
|
|
|
|
遵循PEP 8风格指南,这是Python官方推荐的代码风格指南。
|
|
|
|
|
●缩进:使用4个空格(不要使用制表符)。
|
|
|
|
|
●每行最大长度:不超过79个字符(对于文档字符串和注释,不超过72个字符)。
|
|
|
|
|
●空行:使用空行分隔函数和类,以及函数内的逻辑块。
|
|
|
|
|
●导入:每个导入应该单独成行,且导入顺序如下:
|
|
|
|
|
●标准库导入
|
|
|
|
|
●相关第三方库导入
|
|
|
|
|
●本地应用/库的特定导入
|
|
|
|
|
避免使用空格:在括号、方括号或花括号内部避免不必要的空格。
|
|
|
|
|
|
|
|
|
|
●自动格式化工具:
|
|
|
|
|
使用black、autopep8或yapf等工具自动格式化代码。
|
|
|
|
|
使用isort对导入进行排序。
|
|
|
|
|
|
|
|
|
|
(2)命名规范
|
|
|
|
|
●变量和函数:使用小写字母,单词之间用下划线连接(蛇形命名法),例如:my_variable。
|
|
|
|
|
●常量:使用全大写字母,单词之间用下划线连接,例如:MY_CONSTANT。
|
|
|
|
|
●类:使用驼峰命名法(每个单词首字母大写,不加下划线),例如:MyClass。
|
|
|
|
|
●模块和包:使用简短的小写字母命名,避免使用下划线(如果可读性允许)。
|
|
|
|
|
|
|
|
|
|
(3)注释和文档字符串
|
|
|
|
|
●注释:注释应该是完整的句子,解释代码的意图,而不是描述代码本身。避免不必要的注释。
|
|
|
|
|
●文档字符串(Docstring):使用PEP 257中定义的文档字符串约定。
|
|
|
|
|
●模块、类、函数的第一个语句应该是文档字符串。
|
|
|
|
|
●使用三重双引号"""。
|
|
|
|
|
●对于函数和方法的文档字符串,应说明其功能、参数、返回值和异常。
|
|
|
|
|
|
|
|
|
|
1. 项目结构规范
|
|
|
|
|
1.1 标准化项目布局
|
|
|
|
|
project_name/
|
|
|
|
|
├── 📁 src/ # 源代码目录(推荐使用src布局)
|
|
|
|
|
│ └── package_name/
|
|
|
|
|
│ ├── __init__.py # 包初始化文件
|
|
|
|
|
│ ├── core/ # 核心模块
|
|
|
|
|
│ │ ├── __init__.py
|
|
|
|
|
│ │ ├── models.py # 数据模型定义
|
|
|
|
|
│ │ └── processors.py # 核心处理器
|
|
|
|
|
│ └── utils/ # 工具模块
|
|
|
|
|
│ ├── __init__.py
|
|
|
|
|
│ ├── data_utils.py
|
|
|
|
|
│ └── visualization.py
|
|
|
|
|
├── 📁 tests/ # 测试代码
|
|
|
|
|
│ ├── __init__.py
|
|
|
|
|
│ ├── test_core/
|
|
|
|
|
│ └── test_utils/
|
|
|
|
|
├── 📁 docs/ # 项目文档
|
|
|
|
|
│ ├── conf.py
|
|
|
|
|
│ ├── index.rst
|
|
|
|
|
│ └── api/
|
|
|
|
|
├── 📁 configs/ # 配置文件
|
|
|
|
|
│ ├── default.yaml
|
|
|
|
|
│ ├── development.yaml
|
|
|
|
|
│ └── production.yaml
|
|
|
|
|
├── 📁 data/ # 数据目录
|
|
|
|
|
│ ├── raw/ # 原始数据
|
|
|
|
|
│ ├── processed/ # 处理后的数据
|
|
|
|
|
│ └── external/ # 外部数据
|
|
|
|
|
├── 📁 notebooks/ # Jupyter笔记本
|
|
|
|
|
│ ├── 01-exploratory-analysis.ipynb
|
|
|
|
|
│ └── 02-model-experiments.ipynb
|
|
|
|
|
├── 📁 scripts/ # 脚本文件
|
|
|
|
|
│ ├── train.py
|
|
|
|
|
│ ├── evaluate.py
|
|
|
|
|
│ └── deploy.py
|
|
|
|
|
├── 📄 requirements.txt # 依赖列表
|
|
|
|
|
├── 📄 setup.py # 安装配置
|
|
|
|
|
├── 📄 pyproject.toml # 现代项目配置(推荐)
|
|
|
|
|
├── 📄 README.md # 项目说明
|
|
|
|
|
├── 📄 .gitignore # Git忽略规则
|
|
|
|
|
├── 📄 .pre-commit-config.yaml # 预提交钩子配置
|
|
|
|
|
└── 📄 Makefile # 常用命令封装
|
|
|
|
|
|
|
|
|
|
README.md:包含项目名称、描述、安装步骤、使用示例等。
|
|
|
|
|
requirements.txt:列出项目依赖的第三方库。
|
|
|
|
|
setup.py:用于打包和分发项目。
|
|
|
|
|
docs/:存放项目文档。
|
|
|
|
|
tests/:存放所有测试代码,使用unittest、pytest等框架。
|
|
|
|
|
src/ 或直接使用包名:源代码目录。
|
|
|
|
|
|
|
|
|
|
1.2 目录结构详解
|
|
|
|
|
src布局优势:
|
|
|
|
|
- 避免将项目目录意外添加到Python路径
|
|
|
|
|
- 确保测试针对已安装的包进行
|
|
|
|
|
- 提供清晰的包边界
|
|
|
|
|
|
|
|
|
|
__init__.py文件的作用:
|
|
|
|
|
# package_name/__init__.py
"""Package-level docstring."""

from .core.models import BaseModel
from .utils.data_utils import DataProcessor

# Define the package's public API
__all__ = ['BaseModel', 'DataProcessor']
__version__ = "1.0.0"
|
|
|
|
|
|
|
|
|
|
2. 代码风格与规范
|
|
|
|
|
2.1 PEP 8详细规范
|
|
|
|
|
命名约定
|
|
|
|
|
# ✅ 推荐的命名方式
|
|
|
|
|
class DeepLearningModel: # 类名:驼峰式
|
|
|
|
|
def __init__(self): # 方法名:小写+下划线
|
|
|
|
|
self.model_config = {} # 实例变量:小写+下划线
|
|
|
|
|
self._hidden_layers = [] # 保护变量:单下划线开头
|
|
|
|
|
self.__private_param = 0 # 私有变量:双下划线开头
|
|
|
|
|
|
|
|
|
|
MAX_EPOCHS = 100 # 常量:全大写+下划线
|
|
|
|
|
|
|
|
|
|
def train_model(self): # 公共方法
|
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
def _internal_helper(self): # 内部方法
|
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
# 模块和包命名
|
|
|
|
|
data_processor.py # 模块:小写,简短
|
|
|
|
|
neural_networks/ # 包:小写,不含连字符
|
|
|
|
|
|
|
|
|
|
代码布局规范
|
|
|
|
|
#!/usr/bin/env python3
|
|
|
|
|
# -*- coding: utf-8 -*-
|
|
|
|
|
"""
|
|
|
|
|
模块文档字符串:简要描述模块功能
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
# 导入标准库
|
|
|
|
|
import os
|
|
|
|
|
import sys
|
|
|
|
|
from typing import Dict, List, Optional
|
|
|
|
|
|
|
|
|
|
# 导入第三方库
|
|
|
|
|
import numpy as np
|
|
|
|
|
import torch
|
|
|
|
|
import torch.nn as nn
|
|
|
|
|
|
|
|
|
|
# 导入本地模块
|
|
|
|
|
from .utils import data_processor
|
|
|
|
|
from .core.base_model import BaseModel
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class NeuralNetwork(BaseModel):
    """
    Simple feed-forward neural network built from linear layers.

    Attributes:
        input_size (int): dimensionality of the input layer.
        hidden_sizes (List[int]): sizes of the hidden layers.
        output_size (int): dimensionality of the output layer.
        activation (str): name of the activation function.
    """

    # Activation names accepted by __init__ / used by _get_activation.
    _SUPPORTED_ACTIVATIONS = ("relu", "sigmoid", "tanh")

    def __init__(
        self,
        input_size: int,
        hidden_sizes: List[int],
        output_size: int,
        activation: str = "relu"
    ) -> None:
        """Initialize the neural network.

        Args:
            input_size: number of input features.
            hidden_sizes: list of hidden-layer sizes.
            output_size: dimensionality of the output layer.
            activation: activation function, one of 'relu', 'sigmoid', 'tanh'.

        Raises:
            ValueError: if the activation function is not supported.
        """
        super().__init__()

        # Validate first so a failed construction leaves no
        # half-initialized attributes behind.
        if activation not in self._SUPPORTED_ACTIVATIONS:
            raise ValueError(f"不支持的激活函数: {activation}")

        self.input_size = input_size
        self.hidden_sizes = hidden_sizes
        self.output_size = output_size
        # Bug fix: the class docstring documents an ``activation``
        # attribute, but the original never stored it, and
        # ``_build_layers`` called an undefined ``_get_activation``.
        self.activation = activation

        self.layers = self._build_layers()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass.

        Args:
            x: input tensor of shape (batch_size, input_size).

        Returns:
            Output tensor of shape (batch_size, output_size).
        """
        for layer in self.layers:
            x = layer(x)
        return x

    def _build_layers(self) -> nn.ModuleList:
        """Build the network layers.

        Returns:
            ModuleList of alternating Linear and activation modules;
            the final Linear layer has no activation (raw outputs).
        """
        layers = nn.ModuleList()
        sizes = [self.input_size] + self.hidden_sizes + [self.output_size]

        for i in range(len(sizes) - 1):
            layers.append(nn.Linear(sizes[i], sizes[i + 1]))
            if i < len(sizes) - 2:  # no activation after the last layer
                layers.append(self._get_activation())

        return layers

    def _get_activation(self) -> nn.Module:
        """Return a fresh activation module matching ``self.activation``.

        Defined here because the original example referenced it without
        providing an implementation.
        """
        factories = {
            "relu": nn.ReLU,
            "sigmoid": nn.Sigmoid,
            "tanh": nn.Tanh,
        }
        return factories[self.activation]()
|
|
|
|
|
|
|
|
|
|
2.2 类型注解规范
|
|
|
|
|
from typing import (
|
|
|
|
|
Any, Callable, Dict, List, Optional, Tuple,
|
|
|
|
|
Union, Sequence, Iterator
|
|
|
|
|
)
|
|
|
|
|
import pandas as pd
|
|
|
|
|
from pathlib import Path
|
|
|
|
|
|
|
|
|
|
def load_dataset(
    file_path: Union[str, Path],
    target_column: Optional[str] = None,
    # Bug fix: was `Sequence[str] = None` — an implicit-Optional
    # annotation/default mismatch rejected by modern type checkers.
    feature_columns: Optional[Sequence[str]] = None,
    preprocessors: Optional[List[Callable]] = None
) -> Tuple[pd.DataFrame, Optional[pd.Series]]:
    """Load and preprocess a dataset.

    Args:
        file_path: path to the data file.
        target_column: name of the target column; when None only the
            features are returned.
        feature_columns: columns to use as features; when None all
            columns are used.
        preprocessors: optional list of preprocessing callables.

    Returns:
        Tuple of (feature DataFrame, target Series or None).

    Example:
        >>> X, y = load_dataset("data.csv", target_column="label")
    """
    # Implementation placeholder (this is a documentation example).
    pass
|
|
|
|
|
|
|
|
|
|
# 使用TypeVar进行泛型编程
|
|
|
|
|
from typing import TypeVar, Generic
|
|
|
|
|
|
|
|
|
|
T = TypeVar('T')


class DataLoader(Generic[T]):
    """Generic mini-batch loader.

    Splits a sequence into consecutive batches of at most
    ``batch_size`` items; the final batch may be shorter.
    """

    def __init__(self, data: List[T], batch_size: int = 32):
        self.data = data
        self.batch_size = batch_size

    def __iter__(self) -> Iterator[List[T]]:
        """Yield consecutive slices of ``data``."""
        step = self.batch_size
        total = len(self.data)
        start = 0
        while start < total:
            yield self.data[start:start + step]
            start += step
|
|
|
|
|
|
|
|
|
|
3. 配置管理规范
|
|
|
|
|
●使用pylint、flake8或mypy(用于类型检查)等工具进行静态分析,确保代码质量。
|
|
|
|
|
●在Python 3.5+中,使用类型提示(Type Hints)来标注变量、函数参数和返回值的类型。
|
|
|
|
|
●这可以提高代码的可读性和可维护性,并允许使用静态类型检查工具。
|
|
|
|
|
|
|
|
|
|
3.1 配置文件设计
|
|
|
|
|
```yaml
|
|
|
|
|
# configs/default.yaml
|
|
|
|
|
project:
|
|
|
|
|
name: "deep_learning_project"
|
|
|
|
|
version: "1.0.0"
|
|
|
|
|
description: "深度学习项目模板"
|
|
|
|
|
|
|
|
|
|
data:
|
|
|
|
|
input_path: "./data/raw"
|
|
|
|
|
output_path: "./data/processed"
|
|
|
|
|
batch_size: 64
|
|
|
|
|
num_workers: 8
|
|
|
|
|
train_ratio: 0.8
|
|
|
|
|
validation_ratio: 0.1
|
|
|
|
|
# test_ratio 自动计算为 0.1
|
|
|
|
|
|
|
|
|
|
model:
|
|
|
|
|
architecture: "resnet50"
|
|
|
|
|
input_size: 224
|
|
|
|
|
num_classes: 1000
|
|
|
|
|
pretrained: true
|
|
|
|
|
dropout_rate: 0.5
|
|
|
|
|
|
|
|
|
|
training:
|
|
|
|
|
optimizer:
|
|
|
|
|
name: "adam"
|
|
|
|
|
learning_rate: 0.001
|
|
|
|
|
weight_decay: 0.0001
|
|
|
|
|
scheduler:
|
|
|
|
|
name: "cosine"
|
|
|
|
|
warmup_epochs: 5
|
|
|
|
|
epochs: 100
|
|
|
|
|
early_stopping_patience: 10
|
|
|
|
|
|
|
|
|
|
logging:
|
|
|
|
|
level: "INFO"
|
|
|
|
|
format: "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
|
|
|
|
tensorboard_dir: "./logs/tensorboard"
|
|
|
|
|
checkpoint_dir: "./checkpoints"
|
|
|
|
|
|
|
|
|
|
```

3.2 配置加载与管理
|
|
|
|
|
from dataclasses import dataclass, field
|
|
|
|
|
from typing import List, Optional, Dict, Any
|
|
|
|
|
import yaml
|
|
|
|
|
import json
|
|
|
|
|
from pathlib import Path
|
|
|
|
|
|
|
|
|
|
@dataclass
class OptimizerConfig:
    """Optimizer hyper-parameters."""
    name: str = "adam"  # optimizer identifier, e.g. "adam" or "sgd"
    learning_rate: float = 0.001
    weight_decay: float = 0.0001
    momentum: Optional[float] = None  # only meaningful for momentum-based optimizers
|
|
|
|
|
|
|
|
|
|
@dataclass
class TrainingConfig:
    """Training-loop hyper-parameters."""
    epochs: int = 100
    batch_size: int = 32
    early_stopping_patience: int = 10  # epochs without improvement before stopping
    optimizer: OptimizerConfig = field(default_factory=OptimizerConfig)
|
|
|
|
|
|
|
|
|
|
@dataclass
class ModelConfig:
    """Model architecture hyper-parameters."""
    architecture: str = "resnet50"
    input_size: int = 224  # spatial size of the (square) input
    num_classes: int = 1000
    pretrained: bool = True
    dropout_rate: float = 0.5
|
|
|
|
|
|
|
|
|
|
@dataclass
class ProjectConfig:
    """Top-level project configuration.

    Aggregates the model and training sections and provides helpers
    for loading from YAML and converting back to a plain dict.
    """
    project_name: str
    version: str
    model: ModelConfig = field(default_factory=ModelConfig)
    training: TrainingConfig = field(default_factory=TrainingConfig)

    @classmethod
    def from_yaml(cls, config_path: Path) -> "ProjectConfig":
        """Load a configuration from a YAML file.

        Args:
            config_path: path to the YAML document.

        Returns:
            A fully populated ``ProjectConfig``.
        """
        with open(config_path, 'r', encoding='utf-8') as f:
            config_dict = yaml.safe_load(f)

        return cls._from_dict(config_dict)

    @classmethod
    def _from_dict(cls, config_dict: Dict[str, Any]) -> "ProjectConfig":
        """Build a config object from a plain dictionary.

        Bug fix: the original passed the raw dict to ``cls(**...)``,
        leaving nested sections as plain dicts. Nested ``model``,
        ``training`` and ``training.optimizer`` sections are now
        converted into their dataclasses; unknown keys inside a section
        raise ``TypeError`` from the dataclass constructor, surfacing
        typos early.
        """
        data = dict(config_dict)  # avoid mutating the caller's dict

        model_section = data.get('model')
        if isinstance(model_section, dict):
            data['model'] = ModelConfig(**model_section)

        training_section = data.get('training')
        if isinstance(training_section, dict):
            training_section = dict(training_section)
            optimizer_section = training_section.get('optimizer')
            if isinstance(optimizer_section, dict):
                training_section['optimizer'] = OptimizerConfig(**optimizer_section)
            data['training'] = TrainingConfig(**training_section)

        return cls(**data)

    def to_dict(self) -> Dict[str, Any]:
        """Return the configuration as a plain (nested) dictionary.

        Bug fix: the original body was ``pass`` and returned ``None``
        despite the ``Dict[str, Any]`` annotation.
        """
        # Local import: the file only imports dataclass/field from dataclasses.
        from dataclasses import asdict
        return asdict(self)
|
|
|
|
|
|
|
|
|
|
4. 模型开发规范
|
|
|
|
|
4.1 基础模型抽象
|
|
|
|
|
from abc import ABC, abstractmethod
|
|
|
|
|
from typing import Dict, Any, Optional
|
|
|
|
|
import torch
|
|
|
|
|
import torch.nn as nn
|
|
|
|
|
import logging
|
|
|
|
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
class BaseModel(nn.Module, ABC):
    """Abstract base class for all models.

    Subclasses must implement :meth:`forward` and :meth:`get_loss`.
    Construction validates that the supplied config contains the
    required keys and sets up a per-subclass logger.
    """

    def __init__(self, config: Dict[str, Any]):
        super().__init__()
        self.config = config
        self._setup_logging()
        self._validate_config()

    def _setup_logging(self):
        """Create a logger namespaced to the concrete subclass."""
        self.logger = logging.getLogger(
            f"{__name__}.{self.__class__.__name__}"
        )

    def _validate_config(self):
        """Raise ValueError if a required config key is missing."""
        required_keys = ['input_size', 'output_size']
        for key in required_keys:
            if key not in self.config:
                raise ValueError(f"缺少必需的配置参数: {key}")

    @abstractmethod
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass; must be provided by subclasses."""
        pass

    @abstractmethod
    def get_loss(self, outputs: torch.Tensor,
                 targets: torch.Tensor) -> torch.Tensor:
        """Compute the training loss; must be provided by subclasses."""
        pass

    def save(self, filepath: Path):
        """Serialize weights, config and class name to ``filepath``."""
        checkpoint = {
            'model_state_dict': self.state_dict(),
            'config': self.config,
            'model_class': self.__class__.__name__
        }
        torch.save(checkpoint, filepath)
        self.logger.info(f"模型已保存到: {filepath}")

    @classmethod
    def load(cls, filepath: Path) -> "BaseModel":
        """Rebuild a model from a checkpoint written by :meth:`save`.

        NOTE(review): ``torch.load`` unpickles arbitrary objects —
        only load checkpoints from trusted sources.
        """
        checkpoint = torch.load(filepath, map_location='cpu')

        # Recreate the model instance from the stored config
        model = cls(checkpoint['config'])
        model.load_state_dict(checkpoint['model_state_dict'])

        logger.info(f"模型已从 {filepath} 加载")
        return model

    def count_parameters(self) -> int:
        """Return the number of trainable (requires_grad) parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def freeze_layers(self, layer_names: List[str]):
        """Disable gradients for every parameter whose name contains
        any of the given substrings."""
        for name, param in self.named_parameters():
            if any(layer_name in name for layer_name in layer_names):
                param.requires_grad = False
                self.logger.debug(f"冻结层: {name}")
|
|
|
|
|
|
|
|
|
|
4.2 具体模型实现
|
|
|
|
|
import torch.nn.functional as F
|
|
|
|
|
from typing import List
|
|
|
|
|
|
|
|
|
|
class CNNModel(BaseModel):
    """Convolutional neural network classifier.

    Architecture: a configurable stack of Conv-BN-ReLU-MaxPool blocks
    followed by fully connected layers with dropout and a final linear
    output layer producing raw logits.

    Config keys used (defaults in parentheses):
        input_channels (3), input_size (224), conv_channels
        ([64, 128, 256]), fc_sizes ([512]), dropout_rate (0.5),
        output_size (required; validated by BaseModel).
    """

    def __init__(self, config: Dict[str, Any]):
        super().__init__(config)

        # Convolutional feature extractor
        self.conv_layers = self._build_conv_layers()

        # Fully connected classification head
        self.fc_layers = self._build_fc_layers()

        # Weight initialization (Kaiming for conv, small normal for linear)
        self._initialize_weights()

        self.logger.info(
            f"模型初始化完成,总参数量: {self.count_parameters():,}"
        )

    def _build_conv_layers(self) -> nn.Sequential:
        """Build the Conv-BN-ReLU-MaxPool feature extractor."""
        layers = []
        in_channels = self.config.get('input_channels', 3)
        conv_channels = self.config.get('conv_channels', [64, 128, 256])

        for out_channels in conv_channels:
            layers.extend([
                nn.Conv2d(in_channels, out_channels,
                          kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(2)
            ])
            in_channels = out_channels

        return nn.Sequential(*layers)

    def _build_fc_layers(self) -> nn.Sequential:
        """Build the fully connected head, ending in the output layer."""
        fc_sizes = self.config.get('fc_sizes', [512])
        dropout_rate = self.config.get('dropout_rate', 0.5)

        layers = []
        in_features = self._calculate_fc_input_size()

        for out_features in fc_sizes:
            layers.extend([
                nn.Linear(in_features, out_features),
                nn.ReLU(inplace=True),
                nn.Dropout(dropout_rate)
            ])
            in_features = out_features

        # Output layer: no activation, raw logits for cross-entropy
        layers.append(
            nn.Linear(in_features, self.config['output_size'])
        )

        return nn.Sequential(*layers)

    def _calculate_fc_input_size(self) -> int:
        """Infer the flattened size of the conv-stack output.

        Runs a dummy forward pass through the conv layers. Bug fix:
        the dummy input now honours ``input_channels`` and
        ``input_size`` from the config instead of a hard-coded
        (1, 3, 224, 224), which crashed (or computed a wrong size)
        for any non-default input shape.
        """
        channels = self.config.get('input_channels', 3)
        spatial = self.config.get('input_size', 224)
        with torch.no_grad():
            dummy_input = torch.zeros(1, channels, spatial, spatial)
            dummy_output = self.conv_layers(dummy_input)
            return dummy_output.view(1, -1).size(1)

    def _initialize_weights(self):
        """Apply standard initialization per layer type."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(
                    module.weight, mode='fan_out', nonlinearity='relu'
                )
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, 0, 0.01)
                nn.init.constant_(module.bias, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute class logits for a batch of images."""
        x = self.conv_layers(x)
        x = x.view(x.size(0), -1)  # flatten
        x = self.fc_layers(x)
        return x

    def get_loss(self, outputs: torch.Tensor,
                 targets: torch.Tensor) -> torch.Tensor:
        """Cross-entropy loss on raw logits."""
        return F.cross_entropy(outputs, targets)
|
|
|
|
|
|
|
|
|
|
5. 训练流程规范
|
|
|
|
|
5.1 训练器设计
|
|
|
|
|
import time
|
|
|
|
|
from datetime import datetime
|
|
|
|
|
from pathlib import Path
|
|
|
|
|
from typing import Dict, List, Optional, Callable
|
|
|
|
|
import numpy as np
|
|
|
|
|
from tqdm import tqdm
|
|
|
|
|
|
|
|
|
|
class Trainer:
    """Model trainer.

    Drives the full training loop: per-epoch training and validation,
    optional learning-rate scheduling, best-checkpoint saving and early
    stopping. Experiment artifacts (config snapshot, best model) are
    written under ``experiments/<experiment_name>/``.
    """

    def __init__(
        self,
        model: BaseModel,
        train_loader: torch.utils.data.DataLoader,
        val_loader: torch.utils.data.DataLoader,
        optimizer: torch.optim.Optimizer,
        scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
        device: torch.device = torch.device('cpu'),
        # Bug fix: was `Dict[str, Any] = None` (implicit Optional).
        config: Optional[Dict[str, Any]] = None
    ):
        """Store collaborators and initialize training state.

        Args:
            model: model implementing ``forward`` and ``get_loss``.
            train_loader: yields (data, target) training batches.
            val_loader: yields (data, target) validation batches.
            optimizer: optimizer configured for ``model``'s parameters.
            scheduler: optional LR scheduler stepped once per epoch.
            device: device on which to run training.
            config: optional experiment settings such as
                ``experiment_name`` and ``early_stopping_patience``.
        """
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.device = device
        self.config = config or {}

        # Training state
        self.epoch = 0
        self.best_metric = float('inf')  # best validation loss so far
        self.train_losses: List[float] = []
        self.val_losses: List[float] = []
        self.learning_rates: List[float] = []

        self._setup_experiment_tracking()

    def _setup_experiment_tracking(self):
        """Create the experiment directory and persist the config."""
        experiment_name = self.config.get(
            'experiment_name',
            f"exp_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
        )
        self.experiment_dir = Path("experiments") / experiment_name
        self.experiment_dir.mkdir(parents=True, exist_ok=True)

        # Snapshot the config for reproducibility
        if self.config:
            import yaml
            with open(self.experiment_dir / "config.yaml", 'w') as f:
                yaml.dump(self.config, f)

    def train_epoch(self) -> float:
        """Run one training epoch and return its average loss."""
        self.model.train()
        total_loss = 0.0
        num_batches = 0

        pbar = tqdm(self.train_loader, desc=f"Epoch {self.epoch+1} Training")
        for batch_idx, (data, target) in enumerate(pbar):
            data, target = data.to(self.device), target.to(self.device)

            # Forward pass
            self.optimizer.zero_grad()
            output = self.model(data)
            loss = self.model.get_loss(output, target)

            # Backward pass
            loss.backward()
            self.optimizer.step()

            total_loss += loss.item()
            num_batches += 1

            # Update the progress bar
            pbar.set_postfix({
                'loss': f'{loss.item():.6f}',
                'avg_loss': f'{total_loss/num_batches:.6f}'
            })

        avg_loss = total_loss / num_batches
        self.train_losses.append(avg_loss)
        return avg_loss

    def validate(self) -> Dict[str, float]:
        """Evaluate on the validation set.

        Returns:
            Dict with ``loss`` (mean per-batch loss) and ``accuracy``
            (percentage of correct top-1 predictions).
        """
        self.model.eval()
        total_loss = 0.0
        correct = 0
        total = 0

        with torch.no_grad():
            for data, target in tqdm(self.val_loader, desc="Validation"):
                data, target = data.to(self.device), target.to(self.device)
                output = self.model(data)

                total_loss += self.model.get_loss(output, target).item()
                pred = output.argmax(dim=1)
                correct += pred.eq(target).sum().item()
                total += target.size(0)

        avg_loss = total_loss / len(self.val_loader)
        accuracy = 100. * correct / total

        metrics = {
            'loss': avg_loss,
            'accuracy': accuracy
        }

        self.val_losses.append(avg_loss)
        return metrics

    def train(self, epochs: int) -> Dict[str, List[float]]:
        """Run the full training loop for up to ``epochs`` epochs."""
        self.model.to(self.device)

        for epoch in range(epochs):
            self.epoch = epoch

            # Train
            train_loss = self.train_epoch()

            # Validate
            val_metrics = self.validate()

            # Learning-rate scheduling
            if self.scheduler:
                self.scheduler.step()
                self.learning_rates.append(
                    self.optimizer.param_groups[0]['lr']
                )

            # Logging and checkpointing
            self._log_epoch(epoch, train_loss, val_metrics)
            self._save_checkpoint(val_metrics)

            # Early-stopping check
            if self._check_early_stopping():
                print("早停触发,停止训练")
                break

        return self._get_training_history()

    def _log_epoch(self, epoch: int, train_loss: float,
                   val_metrics: Dict[str, float]):
        """Print a one-line summary for the epoch."""
        log_msg = (
            f"Epoch {epoch+1:03d}: "
            f"Train Loss: {train_loss:.6f}, "
            f"Val Loss: {val_metrics['loss']:.6f}, "
            f"Val Acc: {val_metrics['accuracy']:.2f}%"
        )
        if self.scheduler:
            log_msg += f", LR: {self.learning_rates[-1]:.2e}"

        print(log_msg)

    def _save_checkpoint(self, val_metrics: Dict[str, float]):
        """Save the model whenever the validation loss improves."""
        current_metric = val_metrics['loss']

        # Keep only the best model (lower validation loss is better)
        if current_metric < self.best_metric:
            self.best_metric = current_metric
            checkpoint_path = self.experiment_dir / "best_model.pth"
            self.model.save(checkpoint_path)

    def _check_early_stopping(self) -> bool:
        """Return True when no improvement occurred within the last
        ``early_stopping_patience`` epochs.

        Bug fix: the original compared ``min(recent_losses) >=
        self.best_metric``. Since ``best_metric`` is updated in
        ``_save_checkpoint`` *before* this check, an epoch that just
        set a new best satisfies ``min(recent) == best_metric`` and
        would trigger early stopping as soon as ``patience`` epochs
        had elapsed, regardless of progress. The strict ``>`` stops
        only when every recent loss is worse than the best.
        """
        patience = self.config.get('early_stopping_patience', 10)
        if len(self.val_losses) < patience:
            return False

        recent_losses = self.val_losses[-patience:]
        return min(recent_losses) > self.best_metric

    def _get_training_history(self) -> Dict[str, List[float]]:
        """Return the recorded per-epoch series."""
        return {
            'train_loss': self.train_losses,
            'val_loss': self.val_losses,
            'learning_rate': self.learning_rates
        }
|
|
|
|
|
|
|
|
|
|
6. 测试规范
|
|
|
|
|
测试框架
|
|
|
|
|
●使用unittest、pytest或nose等测试框架。
|
|
|
|
|
●测试代码应该放在tests目录中,测试文件以test_开头,测试函数以test_开头,测试类以Test开头。
|
|
|
|
|
测试覆盖度
|
|
|
|
|
●使用coverage.py等工具测量测试覆盖度,并尽量提高覆盖度。
|
|
|
|
|
|
|
|
|
|
6.1 单元测试
|
|
|
|
|
import pytest
|
|
|
|
|
import torch
|
|
|
|
|
from unittest.mock import Mock, patch
|
|
|
|
|
|
|
|
|
|
class TestCNNModel:
    """Unit tests for CNNModel."""

    @pytest.fixture
    def sample_config(self):
        """Minimal model configuration shared by the tests.

        NOTE(review): BaseModel._validate_config requires both
        'input_size' and 'output_size'; this config omits
        'input_size' — confirm against the model implementation.
        """
        return {
            'input_channels': 3,
            'output_size': 10,
            'conv_channels': [16, 32],
            'fc_sizes': [64]
        }

    @pytest.fixture
    def sample_model(self, sample_config):
        """CNNModel instance built from sample_config."""
        return CNNModel(sample_config)

    @pytest.fixture
    def sample_batch(self):
        """Random batch: four 3x32x32 images with labels in [0, 10)."""
        return torch.randn(4, 3, 32, 32), torch.randint(0, 10, (4,))

    def test_model_initialization(self, sample_model, sample_config):
        """The model stores its config and builds both layer stacks."""
        assert sample_model.config == sample_config
        assert isinstance(sample_model.conv_layers, nn.Sequential)
        assert isinstance(sample_model.fc_layers, nn.Sequential)

    def test_forward_pass(self, sample_model, sample_batch):
        """The forward pass yields finite logits of the expected shape."""
        data, _ = sample_batch
        output = sample_model(data)

        assert output.shape == (4, 10)  # batch_size, num_classes
        assert not torch.isnan(output).any()
        assert not torch.isinf(output).any()

    def test_loss_computation(self, sample_model, sample_batch):
        """The loss is a finite positive scalar tensor."""
        data, target = sample_batch
        output = sample_model(data)
        loss = sample_model.get_loss(output, target)

        assert isinstance(loss, torch.Tensor)
        assert loss.item() > 0
        assert not torch.isnan(loss)

    def test_parameter_count(self, sample_model):
        """The parameter count is a positive integer."""
        num_params = sample_model.count_parameters()
        assert num_params > 0
        assert isinstance(num_params, int)

    @patch('torch.save')
    def test_model_saving(self, mock_save, sample_model, tmp_path):
        """save() delegates to torch.save with the expected payload."""
        save_path = tmp_path / "test_model.pth"
        sample_model.save(save_path)

        mock_save.assert_called_once()
        # Inspect the structure of the saved checkpoint
        call_args = mock_save.call_args[0]
        saved_checkpoint = call_args[0]
        assert 'model_state_dict' in saved_checkpoint
        assert 'config' in saved_checkpoint

    def test_model_loading(self, sample_config, tmp_path):
        """load() round-trips a previously saved checkpoint."""
        # Create and save a model
        original_model = CNNModel(sample_config)
        save_path = tmp_path / "model.pth"

        # Simulated save (a real project would use the full implementation)
        torch.save({
            'model_state_dict': original_model.state_dict(),
            'config': sample_config
        }, save_path)

        # Exercise the loading API
        with patch.object(CNNModel, 'load') as mock_load:
            mock_load.return_value = original_model
            loaded_model = CNNModel.load(save_path)

            assert loaded_model is original_model
|
|
|
|
|
|
|
|
|
|
7. 文档规范
|
|
|
|
|
●使用Sphinx等工具生成项目文档。
|
|
|
|
|
●文档应该包括:安装指南、快速开始、教程、API参考等。
|
|
|
|
|
●除了文档字符串,在复杂逻辑处应添加注释。
|
|
|
|
|
|
|
|
|
|
7.1 代码文档
|
|
|
|
|
"""
|
|
|
|
|
深度学习项目核心模块。
|
|
|
|
|
|
|
|
|
|
这个模块提供了深度学习项目的基础组件,
|
|
|
|
|
包括模型定义、训练流程和工具函数。
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
class AdvancedModel(BaseModel):
    """
    Advanced neural-network model implementation.

    Combines several modern deep-learning techniques, including
    residual connections and attention mechanisms.

    Examples:
        >>> config = {
        ...     'input_size': 784,
        ...     'hidden_sizes': [512, 256],
        ...     'output_size': 10,
        ...     'use_residual': True
        ... }
        >>> model = AdvancedModel(config)
        >>> print(f"参数数量: {model.count_parameters():,}")
        参数数量: 1,234,567

    Note:
        This model requires PyTorch 1.9.0 or newer.

    Warning:
        Large-batch training may require adjusting the learning rate.
    """

    def __init__(self, config: Dict[str, Any]):
        """
        Initialize the advanced model.

        Args:
            config: model configuration dict; must contain:
                - input_size: input dimensionality
                - hidden_sizes: list of hidden-layer sizes
                - output_size: output dimensionality
                - use_residual: whether to use residual connections
                - dropout_rate: dropout probability

        Raises:
            ConfigError: if a required parameter is missing
            ValueError: if a parameter value is invalid
        """
        super().__init__(config)
        # NOTE(review): _build_advanced_architecture is not defined in
        # this excerpt — presumably implemented elsewhere; verify.
        self._build_advanced_architecture()
|
|
|
|
|
|
|
|
|
|
7.2 API文档生成
|
|
|
|
|
●使用.gitignore文件忽略不需要版本控制的文件(如__pycache__/, *.pyc等)。
|
|
|
|
|
●提交信息应清晰描述本次提交的内容。
|
|
|
|
|
●使用requirements.txt或Pipfile(用于Pipenv)来管理依赖。
|
|
|
|
|
●使用虚拟环境(如venv、virtualenv)来隔离项目依赖。
|
|
|
|
|
|
|
|
|
|
# docs/source/conf.py
# Sphinx configuration file

project = '深度学习项目'
copyright = '2024, 你的团队'
author = '你的团队'

extensions = [
    'sphinx.ext.autodoc',   # pull API docs from docstrings
    'sphinx.ext.napoleon',  # Google/NumPy docstring styles
    'sphinx.ext.viewcode',  # link to highlighted source
    'sphinx.ext.mathjax',   # render math in HTML output
]

# Mock heavy dependencies so autodoc can import modules without them installed
autodoc_mock_imports = ['torch', 'numpy', 'pandas']

# Napoleon settings
napoleon_google_docstring = True
napoleon_numpy_docstring = True
napoleon_include_init_with_doc = True
|
|
|
|
|
|
|
|
|
|
8. 工具和自动化
|
|
|
|
|
●使用Travis CI、Jenkins、GitHub Actions等工具设置持续集成,自动运行测试和代码质量检查。
|
|
|
|
|
●使用Python内置的logging模块进行日志记录,避免使用print语句。
|
|
|
|
|
|
|
|
|
|
8.1 预提交钩子配置
|
|
|
|
|
# .pre-commit-config.yaml
|
|
|
|
|
repos:
|
|
|
|
|
- repo: https://github.com/pre-commit/pre-commit-hooks
|
|
|
|
|
rev: v4.4.0
|
|
|
|
|
hooks:
|
|
|
|
|
- id: trailing-whitespace
|
|
|
|
|
- id: end-of-file-fixer
|
|
|
|
|
- id: check-yaml
|
|
|
|
|
- id: check-added-large-files
|
|
|
|
|
|
|
|
|
|
- repo: https://github.com/psf/black
|
|
|
|
|
rev: 23.3.0
|
|
|
|
|
hooks:
|
|
|
|
|
- id: black
|
|
|
|
|
language_version: python3.9
|
|
|
|
|
|
|
|
|
|
- repo: https://github.com/pycqa/isort
|
|
|
|
|
rev: 5.12.0
|
|
|
|
|
hooks:
|
|
|
|
|
- id: isort
|
|
|
|
|
args: ["--profile", "black"]
|
|
|
|
|
|
|
|
|
|
- repo: https://github.com/pycqa/flake8
|
|
|
|
|
rev: 6.0.0
|
|
|
|
|
hooks:
|
|
|
|
|
- id: flake8
|
|
|
|
|
args: ["--max-line-length=88", "--extend-ignore=E203,W503"]
|
|
|
|
|
|
|
|
|
|
- repo: https://github.com/pre-commit/mirrors-mypy
|
|
|
|
|
rev: v1.3.0
|
|
|
|
|
hooks:
|
|
|
|
|
- id: mypy
|
|
|
|
|
additional_dependencies: [types-PyYAML]
|
|
|
|
|
|
|
|
|
|
8.2 Makefile自动化
|
|
|
|
|
# Makefile
|
|
|
|
|
.PHONY: help install test lint format clean
|
|
|
|
|
|
|
|
|
|
help:
|
|
|
|
|
@echo "可用命令:"
|
|
|
|
|
@echo " install 安装开发依赖"
|
|
|
|
|
@echo " test 运行测试"
|
|
|
|
|
@echo " lint 代码检查"
|
|
|
|
|
@echo " format 代码格式化"
|
|
|
|
|
@echo " clean 清理临时文件"
|
|
|
|
|
|
|
|
|
|
install:
|
|
|
|
|
pip install -r requirements.txt
|
|
|
|
|
pip install -e .
|
|
|
|
|
|
|
|
|
|
test:
|
|
|
|
|
pytest tests/ -v --cov=src --cov-report=html
|
|
|
|
|
|
|
|
|
|
lint:
|
|
|
|
|
flake8 src/ tests/
|
|
|
|
|
mypy src/
|
|
|
|
|
|
|
|
|
|
format:
|
|
|
|
|
black src/ tests/
|
|
|
|
|
isort src/ tests/
|
|
|
|
|
|
|
|
|
|
clean:
|
|
|
|
|
find . -type f -name "*.pyc" -delete
|
|
|
|
|
find . -type d -name "__pycache__" -delete
|
|
|
|
|
rm -rf .coverage htmlcov build dist
|
|
|
|
|
|
|
|
|
|
# 训练相关命令
|
|
|
|
|
train:
|
|
|
|
|
python scripts/train.py --config configs/default.yaml
|
|
|
|
|
|
|
|
|
|
evaluate:
|
|
|
|
|
python scripts/evaluate.py --model checkpoints/best_model.pth
|
|
|
|
|
|
|
|
|
|
总结
|
|
|
|
|
遵循这些Python项目代码规范可以带来以下好处:
|
|
|
|
|
1. 可维护性:清晰的代码结构和规范使项目易于理解和维护
|
|
|
|
|
2. 可复现性:完整的配置管理和实验跟踪确保结果可复现
|
|
|
|
|
3. 可扩展性:模块化设计便于添加新功能和模型
|
|
|
|
|
4. 协作效率:统一的规范降低团队协作成本
|
|
|
|
|
5. 代码质量:自动化工具和测试保证代码质量
|
|
|
|
|
建议在项目开始阶段就建立这些规范,并在开发过程中持续执行代码审查来确保规范的落实。
|