feat: 后端代码

master
chaol 3 months ago
parent d0c0d05107
commit 56a233bf72

@ -0,0 +1 @@
# 应用包初始化文件

@ -0,0 +1 @@
# API路由包

@ -0,0 +1,245 @@
"""
文件管理API路由
"""
import os
import mimetypes
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File
from fastapi.responses import FileResponse, StreamingResponse
from sqlalchemy.orm import Session
from typing import List
from app.database import get_db
from app.models.project import Project
import zipfile
import tempfile
import shutil
router = APIRouter()
@router.get("/projects/{project_id}/files")
async def get_project_files(
project_id: int,
path: str = "",
db: Session = Depends(get_db)
):
"""获取项目文件列表"""
project = db.query(Project).filter(Project.id == project_id).first()
if not project:
raise HTTPException(status_code=404, detail="项目不存在")
if not project.project_path or not os.path.exists(project.project_path):
raise HTTPException(status_code=404, detail="项目路径不存在")
target_path = os.path.join(project.project_path, path) if path else project.project_path
if not os.path.exists(target_path):
raise HTTPException(status_code=404, detail="路径不存在")
if not os.path.isdir(target_path):
raise HTTPException(status_code=400, detail="路径不是目录")
files = []
directories = []
try:
for item in os.listdir(target_path):
item_path = os.path.join(target_path, item)
# 跳过隐藏文件和常见的非代码文件
if item.startswith('.') or item in ['node_modules', '__pycache__', '.git', 'venv', 'env']:
continue
item_info = {
'name': item,
'path': os.path.join(path, item).replace('\\', '/'),
'is_directory': os.path.isdir(item_path),
'size': os.path.getsize(item_path) if os.path.isfile(item_path) else 0,
'modified': os.path.getmtime(item_path)
}
if item_info['is_directory']:
directories.append(item_info)
else:
# 只显示代码文件
if _is_code_file(item):
files.append(item_info)
# 排序:目录在前,文件在后,按名称排序
directories.sort(key=lambda x: x['name'].lower())
files.sort(key=lambda x: x['name'].lower())
return {
'files': directories + files,
'current_path': path,
'project_path': project.project_path
}
except PermissionError:
raise HTTPException(status_code=403, detail="权限不足")
except Exception as e:
raise HTTPException(status_code=500, detail=f"读取文件列表失败: {str(e)}")
@router.get("/projects/{project_id}/files/content")
async def get_file_content(
project_id: int,
file_path: str,
db: Session = Depends(get_db)
):
"""获取文件内容"""
project = db.query(Project).filter(Project.id == project_id).first()
if not project:
raise HTTPException(status_code=404, detail="项目不存在")
full_path = os.path.join(project.project_path, file_path)
if not os.path.exists(full_path) or not os.path.isfile(full_path):
raise HTTPException(status_code=404, detail="文件不存在")
if not _is_code_file(full_path):
raise HTTPException(status_code=400, detail="不支持的文件类型")
try:
# 尝试不同的编码
encodings = ['utf-8', 'gbk', 'gb2312', 'latin-1']
content = None
for encoding in encodings:
try:
with open(full_path, 'r', encoding=encoding) as f:
content = f.read()
break
except UnicodeDecodeError:
continue
if content is None:
raise HTTPException(status_code=400, detail="文件编码不支持")
return {
'content': content,
'file_path': file_path,
'size': len(content.encode('utf-8')),
'lines': len(content.splitlines())
}
except PermissionError:
raise HTTPException(status_code=403, detail="权限不足")
except Exception as e:
raise HTTPException(status_code=500, detail=f"读取文件失败: {str(e)}")
@router.post("/projects/{project_id}/files/content")
async def save_file_content(
project_id: int,
file_path: str,
content: str,
db: Session = Depends(get_db)
):
"""保存文件内容"""
project = db.query(Project).filter(Project.id == project_id).first()
if not project:
raise HTTPException(status_code=404, detail="项目不存在")
full_path = os.path.join(project.project_path, file_path)
if not os.path.exists(full_path) or not os.path.isfile(full_path):
raise HTTPException(status_code=404, detail="文件不存在")
try:
# 备份原文件
backup_path = f"{full_path}.backup"
shutil.copy2(full_path, backup_path)
# 写入新内容
with open(full_path, 'w', encoding='utf-8') as f:
f.write(content)
return {"message": "文件保存成功", "backup": backup_path}
except PermissionError:
raise HTTPException(status_code=403, detail="权限不足")
except Exception as e:
raise HTTPException(status_code=500, detail=f"保存文件失败: {str(e)}")
@router.post("/projects/{project_id}/files/upload")
async def upload_files(
project_id: int,
files: List[UploadFile] = File(...),
db: Session = Depends(get_db)
):
"""上传文件到项目"""
project = db.query(Project).filter(Project.id == project_id).first()
if not project:
raise HTTPException(status_code=404, detail="项目不存在")
uploaded_files = []
try:
for file in files:
if not _is_code_file(file.filename):
continue
file_path = os.path.join(project.project_path, file.filename)
with open(file_path, "wb") as buffer:
content = await file.read()
buffer.write(content)
uploaded_files.append({
'filename': file.filename,
'size': len(content)
})
return {"message": f"成功上传 {len(uploaded_files)} 个文件", "files": uploaded_files}
except Exception as e:
raise HTTPException(status_code=500, detail=f"上传文件失败: {str(e)}")
@router.get("/projects/{project_id}/files/download")
async def download_project(
project_id: int,
db: Session = Depends(get_db)
):
"""下载整个项目为ZIP文件"""
project = db.query(Project).filter(Project.id == project_id).first()
if not project:
raise HTTPException(status_code=404, detail="项目不存在")
if not os.path.exists(project.project_path):
raise HTTPException(status_code=404, detail="项目路径不存在")
try:
# 创建临时ZIP文件
temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.zip')
with zipfile.ZipFile(temp_file.name, 'w', zipfile.ZIP_DEFLATED) as zipf:
for root, dirs, files in os.walk(project.project_path):
# 跳过不必要的目录
dirs[:] = [d for d in dirs if not d.startswith('.') and d not in ['node_modules', '__pycache__']]
for file in files:
if _is_code_file(file):
file_path = os.path.join(root, file)
arcname = os.path.relpath(file_path, project.project_path)
zipf.write(file_path, arcname)
return FileResponse(
temp_file.name,
media_type='application/zip',
filename=f"{project.name}.zip"
)
except Exception as e:
raise HTTPException(status_code=500, detail=f"创建下载文件失败: {str(e)}")
def _is_code_file(filename: str) -> bool:
"""判断是否为代码文件"""
code_extensions = {
'.py', '.js', '.jsx', '.ts', '.tsx', '.java', '.cpp', '.c', '.h', '.hpp',
'.cs', '.php', '.rb', '.go', '.rs', '.swift', '.kt', '.scala', '.sh',
'.sql', '.html', '.css', '.scss', '.sass', '.less', '.vue', '.json',
'.xml', '.yaml', '.yml', '.md', '.txt'
}
if isinstance(filename, str):
_, ext = os.path.splitext(filename.lower())
return ext in code_extensions
return False

@ -0,0 +1,77 @@
"""
项目管理API路由
"""
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.orm import Session
from typing import List
from app.database import get_db
from app.models.project import Project
from app.schemas.project import ProjectCreate, ProjectUpdate, ProjectResponse
router = APIRouter()
@router.get("/", response_model=List[ProjectResponse])
async def get_projects(
skip: int = 0,
limit: int = 100,
db: Session = Depends(get_db)
):
"""获取项目列表"""
projects = db.query(Project).filter(Project.is_active == True).offset(skip).limit(limit).all()
return projects
@router.post("/", response_model=ProjectResponse, status_code=status.HTTP_201_CREATED)
async def create_project(
project: ProjectCreate,
db: Session = Depends(get_db)
):
"""创建新项目"""
db_project = Project(**project.dict())
db.add(db_project)
db.commit()
db.refresh(db_project)
return db_project
@router.get("/{project_id}", response_model=ProjectResponse)
async def get_project(
project_id: int,
db: Session = Depends(get_db)
):
"""获取项目详情"""
project = db.query(Project).filter(Project.id == project_id).first()
if not project:
raise HTTPException(status_code=404, detail="项目不存在")
return project
@router.put("/{project_id}", response_model=ProjectResponse)
async def update_project(
project_id: int,
project_update: ProjectUpdate,
db: Session = Depends(get_db)
):
"""更新项目信息"""
project = db.query(Project).filter(Project.id == project_id).first()
if not project:
raise HTTPException(status_code=404, detail="项目不存在")
update_data = project_update.dict(exclude_unset=True)
for field, value in update_data.items():
setattr(project, field, value)
db.commit()
db.refresh(project)
return project
@router.delete("/{project_id}")
async def delete_project(
project_id: int,
db: Session = Depends(get_db)
):
"""删除项目(软删除)"""
project = db.query(Project).filter(Project.id == project_id).first()
if not project:
raise HTTPException(status_code=404, detail="项目不存在")
project.is_active = False
db.commit()
return {"message": "项目已删除"}

@ -0,0 +1,100 @@
"""
报告生成API路由
"""
from fastapi import APIRouter, Depends, HTTPException, Response
from fastapi.responses import FileResponse
from sqlalchemy.orm import Session
from typing import Optional
from app.database import get_db
from app.models.scan import Scan
from app.models.project import Project
from app.services.report_service import ReportService
router = APIRouter()
@router.get("/scan/{scan_id}")
async def generate_scan_report(
scan_id: int,
format: str = "html", # html, pdf, json, excel
db: Session = Depends(get_db)
):
"""生成扫描报告"""
scan = db.query(Scan).filter(Scan.id == scan_id).first()
if not scan:
raise HTTPException(status_code=404, detail="扫描任务不存在")
if scan.status.value != "completed":
raise HTTPException(status_code=400, detail="扫描未完成,无法生成报告")
report_service = ReportService()
if format == "html":
report_path = await report_service.generate_html_report(scan)
return FileResponse(
report_path,
media_type="text/html",
filename=f"scan_report_{scan_id}.html"
)
elif format == "pdf":
report_path = await report_service.generate_pdf_report(scan)
return FileResponse(
report_path,
media_type="application/pdf",
filename=f"scan_report_{scan_id}.pdf"
)
elif format == "json":
report_data = await report_service.generate_json_report(scan)
return report_data
elif format == "excel":
report_path = await report_service.generate_excel_report(scan)
return FileResponse(
report_path,
media_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
filename=f"scan_report_{scan_id}.xlsx"
)
else:
raise HTTPException(status_code=400, detail="不支持的报告格式")
@router.get("/project/{project_id}")
async def generate_project_report(
project_id: int,
format: str = "html",
db: Session = Depends(get_db)
):
"""生成项目汇总报告"""
project = db.query(Project).filter(Project.id == project_id).first()
if not project:
raise HTTPException(status_code=404, detail="项目不存在")
# 获取项目最新的扫描结果
latest_scan = db.query(Scan).filter(
Scan.project_id == project_id,
Scan.status == "completed"
).order_by(Scan.completed_at.desc()).first()
if not latest_scan:
raise HTTPException(status_code=404, detail="项目没有完成的扫描记录")
report_service = ReportService()
if format == "html":
report_path = await report_service.generate_project_html_report(project, latest_scan)
return FileResponse(
report_path,
media_type="text/html",
filename=f"project_report_{project_id}.html"
)
elif format == "json":
report_data = await report_service.generate_project_json_report(project, latest_scan)
return report_data
else:
raise HTTPException(status_code=400, detail="项目报告暂不支持此格式")
@router.get("/dashboard/summary")
async def get_dashboard_summary(db: Session = Depends(get_db)):
"""获取仪表板汇总数据"""
from app.services.dashboard_service import DashboardService
dashboard_service = DashboardService()
summary = await dashboard_service.get_summary_data(db)
return summary

@ -0,0 +1,108 @@
"""
扫描管理API路由
"""
from fastapi import APIRouter, Depends, HTTPException, status, BackgroundTasks
from sqlalchemy.orm import Session
from typing import List
from app.database import get_db
from app.models.scan import Scan, ScanStatus, ScanType
from app.models.project import Project
from app.schemas.scan import ScanCreate, ScanResponse, ScanStatusResponse
from app.services.scan_service import ScanService
router = APIRouter()
@router.get("/", response_model=List[ScanResponse])
async def get_scans(
project_id: int = None,
skip: int = 0,
limit: int = 100,
db: Session = Depends(get_db)
):
"""获取扫描历史"""
query = db.query(Scan)
if project_id:
query = query.filter(Scan.project_id == project_id)
scans = query.order_by(Scan.created_at.desc()).offset(skip).limit(limit).all()
return scans
@router.post("/", response_model=ScanResponse, status_code=status.HTTP_201_CREATED)
async def create_scan(
scan_data: ScanCreate,
background_tasks: BackgroundTasks,
db: Session = Depends(get_db)
):
"""创建并启动扫描任务"""
# 验证项目存在
project = db.query(Project).filter(Project.id == scan_data.project_id).first()
if not project:
raise HTTPException(status_code=404, detail="项目不存在")
# 创建扫描记录
scan = Scan(
project_id=scan_data.project_id,
scan_type=scan_data.scan_type,
scan_config=scan_data.scan_config,
status=ScanStatus.PENDING
)
db.add(scan)
db.commit()
db.refresh(scan)
# 启动后台扫描任务
background_tasks.add_task(ScanService.run_scan, scan.id)
return scan
@router.get("/{scan_id}", response_model=ScanResponse)
async def get_scan(
scan_id: int,
db: Session = Depends(get_db)
):
"""获取扫描详情"""
scan = db.query(Scan).filter(Scan.id == scan_id).first()
if not scan:
raise HTTPException(status_code=404, detail="扫描任务不存在")
return scan
@router.get("/{scan_id}/status", response_model=ScanStatusResponse)
async def get_scan_status(
scan_id: int,
db: Session = Depends(get_db)
):
"""获取扫描状态"""
scan = db.query(Scan).filter(Scan.id == scan_id).first()
if not scan:
raise HTTPException(status_code=404, detail="扫描任务不存在")
return {
"scan_id": scan.id,
"status": scan.status.value,
"progress": {
"total_files": scan.total_files,
"scanned_files": scan.scanned_files,
"percentage": (scan.scanned_files / scan.total_files * 100) if scan.total_files > 0 else 0
},
"started_at": scan.started_at,
"completed_at": scan.completed_at,
"error_message": scan.error_message
}
@router.post("/{scan_id}/cancel")
async def cancel_scan(
scan_id: int,
db: Session = Depends(get_db)
):
"""取消扫描任务"""
scan = db.query(Scan).filter(Scan.id == scan_id).first()
if not scan:
raise HTTPException(status_code=404, detail="扫描任务不存在")
if scan.status in [ScanStatus.COMPLETED, ScanStatus.FAILED, ScanStatus.CANCELLED]:
raise HTTPException(status_code=400, detail="扫描任务已完成或已取消")
scan.status = ScanStatus.CANCELLED
db.commit()
return {"message": "扫描任务已取消"}

@ -0,0 +1,118 @@
"""
漏洞管理API路由
"""
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy.orm import Session
from typing import List, Optional
from app.database import get_db
from app.models.vulnerability import Vulnerability, VulnerabilityStatus, SeverityLevel, VulnerabilityCategory
from app.schemas.vulnerability import VulnerabilityResponse, VulnerabilityUpdate
router = APIRouter()
@router.get("/", response_model=List[VulnerabilityResponse])
async def get_vulnerabilities(
scan_id: Optional[int] = None,
project_id: Optional[int] = None,
severity: Optional[SeverityLevel] = None,
category: Optional[VulnerabilityCategory] = None,
status: Optional[VulnerabilityStatus] = None,
skip: int = 0,
limit: int = 100,
db: Session = Depends(get_db)
):
"""获取漏洞列表"""
query = db.query(Vulnerability)
# 应用过滤条件
if scan_id:
query = query.filter(Vulnerability.scan_id == scan_id)
if severity:
query = query.filter(Vulnerability.severity == severity)
if category:
query = query.filter(Vulnerability.category == category)
if status:
query = query.filter(Vulnerability.status == status)
vulnerabilities = query.order_by(
Vulnerability.severity.desc(),
Vulnerability.line_number
).offset(skip).limit(limit).all()
return vulnerabilities
@router.get("/{vulnerability_id}", response_model=VulnerabilityResponse)
async def get_vulnerability(
vulnerability_id: int,
db: Session = Depends(get_db)
):
"""获取漏洞详情"""
vulnerability = db.query(Vulnerability).filter(Vulnerability.id == vulnerability_id).first()
if not vulnerability:
raise HTTPException(status_code=404, detail="漏洞不存在")
return vulnerability
@router.put("/{vulnerability_id}", response_model=VulnerabilityResponse)
async def update_vulnerability(
vulnerability_id: int,
vulnerability_update: VulnerabilityUpdate,
db: Session = Depends(get_db)
):
"""更新漏洞状态"""
vulnerability = db.query(Vulnerability).filter(Vulnerability.id == vulnerability_id).first()
if not vulnerability:
raise HTTPException(status_code=404, detail="漏洞不存在")
update_data = vulnerability_update.dict(exclude_unset=True)
for field, value in update_data.items():
setattr(vulnerability, field, value)
db.commit()
db.refresh(vulnerability)
return vulnerability
@router.get("/stats/summary")
async def get_vulnerability_stats(
scan_id: Optional[int] = None,
project_id: Optional[int] = None,
db: Session = Depends(get_db)
):
"""获取漏洞统计信息"""
query = db.query(Vulnerability)
if scan_id:
query = query.filter(Vulnerability.scan_id == scan_id)
elif project_id:
# 通过项目ID获取最新的扫描ID
from app.models.scan import Scan
latest_scan = db.query(Scan).filter(
Scan.project_id == project_id,
Scan.status == "completed"
).order_by(Scan.completed_at.desc()).first()
if latest_scan:
query = query.filter(Vulnerability.scan_id == latest_scan.id)
else:
return {"total": 0, "by_severity": {}, "by_category": {}}
vulnerabilities = query.all()
# 统计信息
total = len(vulnerabilities)
by_severity = {}
by_category = {}
for vuln in vulnerabilities:
# 按严重程度统计
severity = vuln.severity.value
by_severity[severity] = by_severity.get(severity, 0) + 1
# 按分类统计
category = vuln.category.value
by_category[category] = by_category.get(category, 0) + 1
return {
"total": total,
"by_severity": by_severity,
"by_category": by_category
}

@ -0,0 +1 @@
# 核心模块包

@ -0,0 +1,78 @@
"""
分析器基类
"""
from abc import ABC, abstractmethod
from typing import List, Dict, Any
import os
import glob
class BaseAnalyzer(ABC):
"""分析器基类"""
def __init__(self):
self.name = "Base Analyzer"
self.version = "1.0.0"
self.supported_extensions = []
self.description = "基础分析器"
@abstractmethod
async def analyze(self, project_path: str, config: Dict[str, Any] = None) -> List[Dict[str, Any]]:
"""
分析项目代码
Args:
project_path: 项目路径
config: 分析配置
Returns:
漏洞列表
"""
pass
def get_project_files(self, project_path: str) -> List[str]:
"""获取项目中的所有文件"""
files = []
for ext in self.supported_extensions:
pattern = os.path.join(project_path, "**", f"*.{ext}")
files.extend(glob.glob(pattern, recursive=True))
return files
def read_file_content(self, file_path: str) -> str:
"""读取文件内容"""
try:
with open(file_path, 'r', encoding='utf-8') as f:
return f.read()
except UnicodeDecodeError:
# 如果UTF-8解码失败尝试其他编码
try:
with open(file_path, 'r', encoding='gbk') as f:
return f.read()
except:
return ""
except Exception:
return ""
def create_vulnerability(
self,
rule_id: str,
message: str,
file_path: str,
line_number: int = None,
severity: str = "medium",
category: str = "maintainability",
code_snippet: str = "",
context_before: str = "",
context_after: str = ""
) -> Dict[str, Any]:
"""创建漏洞对象"""
return {
'rule_id': rule_id,
'message': message,
'file_path': file_path,
'line_number': line_number,
'severity': severity,
'category': category,
'code_snippet': code_snippet,
'context_before': context_before,
'context_after': context_after
}

@ -0,0 +1,187 @@
"""
C++代码分析器
"""
import re
import os
from typing import List, Dict, Any
from .base_analyzer import BaseAnalyzer
class CppAnalyzer(BaseAnalyzer):
"""C++代码分析器"""
def __init__(self):
super().__init__()
self.name = "C++ Analyzer"
self.version = "1.0.0"
self.supported_extensions = ["cpp", "cc", "cxx", "hpp", "h"]
self.description = "C++代码静态分析器"
async def analyze(self, project_path: str, config: Dict[str, Any] = None) -> List[Dict[str, Any]]:
"""分析C++代码"""
vulnerabilities = []
# 获取所有C++文件
cpp_files = self.get_project_files(project_path)
for file_path in cpp_files:
try:
# 读取文件内容
content = self.read_file_content(file_path)
if not content:
continue
lines = content.split('\n')
# 执行各种检查
vulnerabilities.extend(self._check_memory_issues(lines, file_path))
vulnerabilities.extend(self._check_security_issues(lines, file_path))
vulnerabilities.extend(self._check_performance_issues(lines, file_path))
except Exception as e:
# 分析错误
vulnerabilities.append(self.create_vulnerability(
rule_id="CPP000",
message=f"分析错误: {str(e)}",
file_path=file_path,
severity="high",
category="reliability"
))
return vulnerabilities
def _check_memory_issues(self, lines: List[str], file_path: str) -> List[Dict[str, Any]]:
"""检查内存相关问题"""
vulnerabilities = []
for i, line in enumerate(lines):
line_num = i + 1
line_stripped = line.strip()
# 检查裸指针使用
if re.search(r'\*[a-zA-Z_][a-zA-Z0-9_]*\s*[=;]', line_stripped):
vulnerabilities.append(self.create_vulnerability(
rule_id="CPP101",
message="使用裸指针,建议使用智能指针",
file_path=file_path,
line_number=line_num,
severity="medium",
category="reliability",
code_snippet=line_stripped
))
# 检查malloc/free使用
elif 'malloc(' in line_stripped or 'free(' in line_stripped:
vulnerabilities.append(self.create_vulnerability(
rule_id="CPP102",
message="使用malloc/free建议使用new/delete或智能指针",
file_path=file_path,
line_number=line_num,
severity="medium",
category="reliability",
code_snippet=line_stripped
))
# 检查数组越界风险
elif re.search(r'\[[a-zA-Z_][a-zA-Z0-9_]*\]', line_stripped):
vulnerabilities.append(self.create_vulnerability(
rule_id="CPP103",
message="数组访问未进行边界检查",
file_path=file_path,
line_number=line_num,
severity="high",
category="reliability",
code_snippet=line_stripped
))
return vulnerabilities
def _check_security_issues(self, lines: List[str], file_path: str) -> List[Dict[str, Any]]:
"""检查安全问题"""
vulnerabilities = []
for i, line in enumerate(lines):
line_num = i + 1
line_stripped = line.strip()
# 检查strcpy使用
if 'strcpy(' in line_stripped:
vulnerabilities.append(self.create_vulnerability(
rule_id="CPP201",
message="使用strcpy(),存在缓冲区溢出风险",
file_path=file_path,
line_number=line_num,
severity="critical",
category="security",
code_snippet=line_stripped
))
# 检查sprintf使用
elif 'sprintf(' in line_stripped:
vulnerabilities.append(self.create_vulnerability(
rule_id="CPP202",
message="使用sprintf(),存在缓冲区溢出风险",
file_path=file_path,
line_number=line_num,
severity="critical",
category="security",
code_snippet=line_stripped
))
# 检查gets使用
elif 'gets(' in line_stripped:
vulnerabilities.append(self.create_vulnerability(
rule_id="CPP203",
message="使用gets(),存在缓冲区溢出风险",
file_path=file_path,
line_number=line_num,
severity="critical",
category="security",
code_snippet=line_stripped
))
# 检查system调用
elif 'system(' in line_stripped:
vulnerabilities.append(self.create_vulnerability(
rule_id="CPP204",
message="使用system()调用,存在命令注入风险",
file_path=file_path,
line_number=line_num,
severity="high",
category="security",
code_snippet=line_stripped
))
return vulnerabilities
def _check_performance_issues(self, lines: List[str], file_path: str) -> List[Dict[str, Any]]:
"""检查性能问题"""
vulnerabilities = []
for i, line in enumerate(lines):
line_num = i + 1
line_stripped = line.strip()
# 检查循环中的字符串连接
if re.search(r'for\s*\([^)]*\)\s*\{', line_stripped):
# 查找循环体中的字符串连接
for j in range(i + 1, min(i + 20, len(lines))): # 检查循环体前20行
loop_line = lines[j].strip()
if loop_line == '}':
break
if '+' in loop_line and ('"' in loop_line or "'" in loop_line):
vulnerabilities.append(self.create_vulnerability(
rule_id="CPP301",
message="循环中使用字符串连接,影响性能",
file_path=file_path,
line_number=j + 1,
severity="low",
category="performance",
code_snippet=loop_line
))
# 检查未使用的头文件包含
elif line_stripped.startswith('#include'):
# 这里可以添加更复杂的检查逻辑
pass
return vulnerabilities

@ -0,0 +1,233 @@
"""
JavaScript代码分析器
"""
import re
import os
from typing import List, Dict, Any
from .base_analyzer import BaseAnalyzer
class JavaScriptAnalyzer(BaseAnalyzer):
"""JavaScript代码分析器"""
def __init__(self):
super().__init__()
self.name = "JavaScript Analyzer"
self.version = "1.0.0"
self.supported_extensions = ["js", "jsx", "ts", "tsx"]
self.description = "JavaScript/TypeScript代码静态分析器"
async def analyze(self, project_path: str, config: Dict[str, Any] = None) -> List[Dict[str, Any]]:
"""分析JavaScript代码"""
vulnerabilities = []
# 获取所有JavaScript文件
js_files = self.get_project_files(project_path)
for file_path in js_files:
try:
# 读取文件内容
content = self.read_file_content(file_path)
if not content:
continue
lines = content.split('\n')
# 执行各种检查
vulnerabilities.extend(self._check_security_issues(lines, file_path))
vulnerabilities.extend(self._check_performance_issues(lines, file_path))
vulnerabilities.extend(self._check_maintainability_issues(lines, file_path))
except Exception as e:
# 分析错误
vulnerabilities.append(self.create_vulnerability(
rule_id="JS000",
message=f"分析错误: {str(e)}",
file_path=file_path,
severity="high",
category="reliability"
))
return vulnerabilities
def _check_security_issues(self, lines: List[str], file_path: str) -> List[Dict[str, Any]]:
"""检查安全问题"""
vulnerabilities = []
for i, line in enumerate(lines):
line_num = i + 1
line_stripped = line.strip()
# 检查eval使用
if 'eval(' in line_stripped:
vulnerabilities.append(self.create_vulnerability(
rule_id="JS101",
message="使用eval(),存在代码注入风险",
file_path=file_path,
line_number=line_num,
severity="critical",
category="security",
code_snippet=line_stripped
))
# 检查innerHTML使用
elif 'innerHTML' in line_stripped and '=' in line_stripped:
vulnerabilities.append(self.create_vulnerability(
rule_id="JS102",
message="使用innerHTML存在XSS风险",
file_path=file_path,
line_number=line_num,
severity="high",
category="security",
code_snippet=line_stripped
))
# 检查document.write使用
elif 'document.write(' in line_stripped:
vulnerabilities.append(self.create_vulnerability(
rule_id="JS103",
message="使用document.write()存在XSS风险",
file_path=file_path,
line_number=line_num,
severity="high",
category="security",
code_snippet=line_stripped
))
# 检查console.log在生产环境中的使用
elif 'console.log(' in line_stripped:
vulnerabilities.append(self.create_vulnerability(
rule_id="JS104",
message="生产环境中不应使用console.log",
file_path=file_path,
line_number=line_num,
severity="low",
category="security",
code_snippet=line_stripped
))
# 检查硬编码的敏感信息
elif re.search(r'password\s*[:=]\s*["\'][^"\']+["\']', line_stripped, re.IGNORECASE):
vulnerabilities.append(self.create_vulnerability(
rule_id="JS105",
message="代码中包含硬编码的密码",
file_path=file_path,
line_number=line_num,
severity="high",
category="security",
code_snippet=line_stripped
))
return vulnerabilities
def _check_performance_issues(self, lines: List[str], file_path: str) -> List[Dict[str, Any]]:
"""检查性能问题"""
vulnerabilities = []
for i, line in enumerate(lines):
line_num = i + 1
line_stripped = line.strip()
# 检查DOM操作在循环中
if re.search(r'for\s*\([^)]*\)\s*\{', line_stripped):
# 查找循环体中的DOM操作
for j in range(i + 1, min(i + 20, len(lines))):
loop_line = lines[j].strip()
if loop_line == '}':
break
if any(dom_op in loop_line for dom_op in ['getElementById', 'querySelector', 'appendChild']):
vulnerabilities.append(self.create_vulnerability(
rule_id="JS201",
message="循环中进行DOM操作影响性能",
file_path=file_path,
line_number=j + 1,
severity="medium",
category="performance",
code_snippet=loop_line
))
# 检查未使用的变量声明
elif line_stripped.startswith('var ') or line_stripped.startswith('let ') or line_stripped.startswith('const '):
var_name = re.search(r'(var|let|const)\s+([a-zA-Z_$][a-zA-Z0-9_$]*)', line_stripped)
if var_name:
var_name = var_name.group(2)
# 检查变量是否在后续代码中使用
is_used = False
for k in range(i + 1, len(lines)):
if var_name in lines[k]:
is_used = True
break
if not is_used:
vulnerabilities.append(self.create_vulnerability(
rule_id="JS202",
message=f"未使用的变量: {var_name}",
file_path=file_path,
line_number=line_num,
severity="low",
category="performance",
code_snippet=line_stripped
))
return vulnerabilities
def _check_maintainability_issues(self, lines: List[str], file_path: str) -> List[Dict[str, Any]]:
"""检查可维护性问题"""
vulnerabilities = []
for i, line in enumerate(lines):
line_num = i + 1
line_stripped = line.strip()
# 检查函数长度(简单检查)
if line_stripped.startswith('function ') or re.match(r'const\s+\w+\s*=\s*\([^)]*\)\s*=>', line_stripped):
# 计算函数体长度
brace_count = 0
function_start = i
for j in range(i, len(lines)):
line_content = lines[j]
brace_count += line_content.count('{')
brace_count -= line_content.count('}')
if brace_count > 0 and j > i:
continue
elif brace_count == 0 and j > i:
function_length = j - function_start
if function_length > 30:
vulnerabilities.append(self.create_vulnerability(
rule_id="JS301",
message="函数过长,建议拆分",
file_path=file_path,
line_number=line_num,
severity="medium",
category="maintainability",
code_snippet=line_stripped
))
break
# 检查深度嵌套
elif line_stripped.endswith('{'):
indent_level = len(line) - len(line.lstrip())
if indent_level > 40: # 假设每层缩进4个空格
vulnerabilities.append(self.create_vulnerability(
rule_id="JS302",
message="代码嵌套过深,影响可读性",
file_path=file_path,
line_number=line_num,
severity="low",
category="maintainability",
code_snippet=line_stripped
))
# 检查魔法数字
elif re.search(r'\b\d{3,}\b', line_stripped):
vulnerabilities.append(self.create_vulnerability(
rule_id="JS303",
message="存在魔法数字,建议使用常量",
file_path=file_path,
line_number=line_num,
severity="low",
category="maintainability",
code_snippet=line_stripped
))
return vulnerabilities

@ -0,0 +1,193 @@
"""
Python代码分析器
"""
import ast
import os
from typing import List, Dict, Any
from .base_analyzer import BaseAnalyzer
class PythonAnalyzer(BaseAnalyzer):
"""Python代码分析器"""
def __init__(self):
super().__init__()
self.name = "Python Analyzer"
self.version = "1.0.0"
self.supported_extensions = ["py"]
self.description = "Python代码静态分析器"
async def analyze(self, project_path: str, config: Dict[str, Any] = None) -> List[Dict[str, Any]]:
"""分析Python代码"""
vulnerabilities = []
# 获取所有Python文件
python_files = self.get_project_files(project_path)
for file_path in python_files:
try:
# 读取文件内容
content = self.read_file_content(file_path)
if not content:
continue
# 解析AST
tree = ast.parse(content, filename=file_path)
# 执行各种检查
vulnerabilities.extend(self._check_security_issues(tree, file_path, content))
vulnerabilities.extend(self._check_performance_issues(tree, file_path, content))
vulnerabilities.extend(self._check_maintainability_issues(tree, file_path, content))
except SyntaxError as e:
# 语法错误
vulnerabilities.append(self.create_vulnerability(
rule_id="PY001",
message=f"语法错误: {str(e)}",
file_path=file_path,
line_number=e.lineno,
severity="critical",
category="reliability"
))
except Exception as e:
# 其他错误
vulnerabilities.append(self.create_vulnerability(
rule_id="PY000",
message=f"分析错误: {str(e)}",
file_path=file_path,
severity="high",
category="reliability"
))
return vulnerabilities
def _check_security_issues(self, tree: ast.AST, file_path: str, content: str) -> List[Dict[str, Any]]:
"""检查安全问题"""
vulnerabilities = []
lines = content.split('\n')
for node in ast.walk(tree):
# 检查eval使用
if isinstance(node, ast.Call) and isinstance(node.func, ast.Name) and node.func.id == 'eval':
vulnerabilities.append(self.create_vulnerability(
rule_id="PY101",
message="使用了eval()函数,存在代码注入风险",
file_path=file_path,
line_number=node.lineno,
severity="critical",
category="security",
code_snippet=lines[node.lineno - 1].strip() if node.lineno <= len(lines) else ""
))
# 检查exec使用
elif isinstance(node, ast.Call) and isinstance(node.func, ast.Name) and node.func.id == 'exec':
vulnerabilities.append(self.create_vulnerability(
rule_id="PY102",
message="使用了exec()函数,存在代码注入风险",
file_path=file_path,
line_number=node.lineno,
severity="critical",
category="security",
code_snippet=lines[node.lineno - 1].strip() if node.lineno <= len(lines) else ""
))
# 检查pickle使用
elif isinstance(node, ast.Call) and isinstance(node.func, ast.Attribute):
if (isinstance(node.func.value, ast.Name) and
node.func.value.id == 'pickle' and
node.func.attr in ['loads', 'load']):
vulnerabilities.append(self.create_vulnerability(
rule_id="PY103",
message="使用了pickle反序列化存在安全风险",
file_path=file_path,
line_number=node.lineno,
severity="high",
category="security",
code_snippet=lines[node.lineno - 1].strip() if node.lineno <= len(lines) else ""
))
return vulnerabilities
def _check_performance_issues(self, tree: ast.AST, file_path: str, content: str) -> List[Dict[str, Any]]:
"""检查性能问题"""
vulnerabilities = []
lines = content.split('\n')
for node in ast.walk(tree):
# 检查列表推导式中的循环
if isinstance(node, ast.ListComp):
if len(node.generators) > 1:
vulnerabilities.append(self.create_vulnerability(
rule_id="PY201",
message="复杂的列表推导式可能影响性能",
file_path=file_path,
line_number=node.lineno,
severity="medium",
category="performance",
code_snippet=lines[node.lineno - 1].strip() if node.lineno <= len(lines) else ""
))
# 检查全局变量使用
elif isinstance(node, ast.Global):
vulnerabilities.append(self.create_vulnerability(
rule_id="PY202",
message="使用全局变量可能影响性能",
file_path=file_path,
line_number=node.lineno,
severity="low",
category="performance",
code_snippet=lines[node.lineno - 1].strip() if node.lineno <= len(lines) else ""
))
return vulnerabilities
def _check_maintainability_issues(self, tree: ast.AST, file_path: str, content: str) -> List[Dict[str, Any]]:
"""检查可维护性问题"""
vulnerabilities = []
lines = content.split('\n')
# 检查函数长度
for node in ast.walk(tree):
if isinstance(node, ast.FunctionDef):
if len(node.body) > 20:
vulnerabilities.append(self.create_vulnerability(
rule_id="PY301",
message=f"函数 '{node.name}' 过长,建议拆分",
file_path=file_path,
line_number=node.lineno,
severity="medium",
category="maintainability",
code_snippet=f"def {node.name}(...):"
))
# 检查类长度
elif isinstance(node, ast.ClassDef):
if len(node.body) > 15:
vulnerabilities.append(self.create_vulnerability(
rule_id="PY302",
message=f"'{node.name}' 过长,建议拆分",
file_path=file_path,
line_number=node.lineno,
severity="medium",
category="maintainability",
code_snippet=f"class {node.name}:"
))
# 检查循环嵌套
elif isinstance(node, (ast.For, ast.While)):
nested_loops = 0
for child in ast.walk(node):
if isinstance(child, (ast.For, ast.While)) and child != node:
nested_loops += 1
if nested_loops > 2:
vulnerabilities.append(self.create_vulnerability(
rule_id="PY303",
message="循环嵌套过深,影响代码可读性",
file_path=file_path,
line_number=node.lineno,
severity="low",
category="maintainability",
code_snippet=lines[node.lineno - 1].strip() if node.lineno <= len(lines) else ""
))
return vulnerabilities

@ -0,0 +1,35 @@
"""
数据库配置和连接管理
"""
from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
import os
# 数据库URL配置
DATABASE_URL = os.getenv("DATABASE_URL", "sqlite:///./code_scanner.db")
# 创建数据库引擎
engine = create_engine(
DATABASE_URL,
connect_args={"check_same_thread": False} if "sqlite" in DATABASE_URL else {}
)
# 创建会话工厂
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
# 创建基础模型类
Base = declarative_base()
def get_db():
"""获取数据库会话"""
db = SessionLocal()
try:
yield db
finally:
db.close()
def init_db():
"""初始化数据库表"""
from app.models import project, scan, vulnerability
Base.metadata.create_all(bind=engine)

@ -0,0 +1,4 @@
# 数据模型包
from .project import Project
from .scan import Scan
from .vulnerability import Vulnerability

@ -0,0 +1,32 @@
"""
项目数据模型
"""
from sqlalchemy import Column, Integer, String, DateTime, Text, Boolean
from sqlalchemy.sql import func
from sqlalchemy.orm import relationship
from app.database import Base
class Project(Base):
"""项目模型"""
__tablename__ = "projects"
id = Column(Integer, primary_key=True, index=True)
name = Column(String(100), nullable=False, index=True)
description = Column(Text)
language = Column(String(20), nullable=False) # Python, C++, JavaScript等
repository_url = Column(String(500))
project_path = Column(String(500)) # 本地项目路径
config = Column(Text) # JSON格式的配置信息
# 状态字段
is_active = Column(Boolean, default=True)
# 时间戳
created_at = Column(DateTime(timezone=True), server_default=func.now())
updated_at = Column(DateTime(timezone=True), onupdate=func.now())
# 关联关系
scans = relationship("Scan", back_populates="project", cascade="all, delete-orphan")
def __repr__(self):
return f"<Project(id={self.id}, name='{self.name}', language='{self.language}')>"

@ -0,0 +1,57 @@
"""
扫描任务数据模型
"""
from sqlalchemy import Column, Integer, String, DateTime, Text, Boolean, ForeignKey, Enum
from sqlalchemy.sql import func
from sqlalchemy.orm import relationship
import enum
from app.database import Base
class ScanStatus(enum.Enum):
"""扫描状态枚举"""
PENDING = "pending"
RUNNING = "running"
COMPLETED = "completed"
FAILED = "failed"
CANCELLED = "cancelled"
class ScanType(enum.Enum):
"""扫描类型枚举"""
FULL = "full" # 全量扫描
INCREMENTAL = "incremental" # 增量扫描
CUSTOM = "custom" # 自定义扫描
class Scan(Base):
"""扫描任务模型"""
__tablename__ = "scans"
id = Column(Integer, primary_key=True, index=True)
project_id = Column(Integer, ForeignKey("projects.id"), nullable=False)
# 扫描配置
scan_type = Column(Enum(ScanType), default=ScanType.FULL)
scan_config = Column(Text) # JSON格式的扫描配置
# 扫描状态
status = Column(Enum(ScanStatus), default=ScanStatus.PENDING)
# 扫描统计
total_files = Column(Integer, default=0)
scanned_files = Column(Integer, default=0)
total_vulnerabilities = Column(Integer, default=0)
# 时间戳
started_at = Column(DateTime(timezone=True))
completed_at = Column(DateTime(timezone=True))
created_at = Column(DateTime(timezone=True), server_default=func.now())
# 结果信息
result_summary = Column(Text) # JSON格式的结果摘要
error_message = Column(Text) # 错误信息
# 关联关系
project = relationship("Project", back_populates="scans")
vulnerabilities = relationship("Vulnerability", back_populates="scan", cascade="all, delete-orphan")
def __repr__(self):
return f"<Scan(id={self.id}, project_id={self.project_id}, status='{self.status.value}')>"

@ -0,0 +1,77 @@
"""
漏洞数据模型
"""
from sqlalchemy import Column, Integer, String, DateTime, Text, Boolean, ForeignKey, Enum, Float
from sqlalchemy.sql import func
from sqlalchemy.orm import relationship
import enum
from app.database import Base
class SeverityLevel(enum.Enum):
"""严重程度枚举"""
CRITICAL = "critical"
HIGH = "high"
MEDIUM = "medium"
LOW = "low"
INFO = "info"
class VulnerabilityCategory(enum.Enum):
"""漏洞分类枚举"""
SECURITY = "security"
PERFORMANCE = "performance"
MAINTAINABILITY = "maintainability"
RELIABILITY = "reliability"
USABILITY = "usability"
class VulnerabilityStatus(enum.Enum):
"""漏洞状态枚举"""
OPEN = "open"
FIXED = "fixed"
FALSE_POSITIVE = "false_positive"
WONT_FIX = "wont_fix"
class Vulnerability(Base):
"""漏洞模型"""
__tablename__ = "vulnerabilities"
id = Column(Integer, primary_key=True, index=True)
scan_id = Column(Integer, ForeignKey("scans.id"), nullable=False)
# 漏洞基本信息
rule_id = Column(String(100), nullable=False) # 规则ID
message = Column(Text, nullable=False) # 漏洞描述
category = Column(Enum(VulnerabilityCategory), nullable=False)
severity = Column(Enum(SeverityLevel), nullable=False)
# 位置信息
file_path = Column(String(500), nullable=False)
line_number = Column(Integer)
column_number = Column(Integer)
end_line = Column(Integer)
end_column = Column(Integer)
# 代码上下文
code_snippet = Column(Text) # 相关代码片段
context_before = Column(Text) # 前置代码上下文
context_after = Column(Text) # 后置代码上下文
# AI增强信息
ai_enhanced = Column(Boolean, default=False)
ai_confidence = Column(Float) # AI置信度 0-1
ai_suggestion = Column(Text) # AI修复建议
# 状态管理
status = Column(Enum(VulnerabilityStatus), default=VulnerabilityStatus.OPEN)
assigned_to = Column(String(100)) # 分配给谁
fix_commit = Column(String(100)) # 修复的提交哈希
# 时间戳
created_at = Column(DateTime(timezone=True), server_default=func.now())
updated_at = Column(DateTime(timezone=True), onupdate=func.now())
fixed_at = Column(DateTime(timezone=True))
# 关联关系
scan = relationship("Scan", back_populates="vulnerabilities")
def __repr__(self):
return f"<Vulnerability(id={self.id}, rule_id='{self.rule_id}', severity='{self.severity.value}')>"

@ -0,0 +1 @@
# 数据模式包

@ -0,0 +1,37 @@
"""
项目相关的Pydantic模式
"""
from pydantic import BaseModel
from typing import Optional
from datetime import datetime
class ProjectBase(BaseModel):
"""项目基础模式"""
name: str
description: Optional[str] = None
language: str
repository_url: Optional[str] = None
project_path: Optional[str] = None
config: Optional[str] = None
class ProjectCreate(ProjectBase):
"""创建项目模式"""
pass
class ProjectUpdate(BaseModel):
"""更新项目模式"""
name: Optional[str] = None
description: Optional[str] = None
repository_url: Optional[str] = None
project_path: Optional[str] = None
config: Optional[str] = None
class ProjectResponse(ProjectBase):
"""项目响应模式"""
id: int
is_active: bool
created_at: datetime
updated_at: Optional[datetime] = None
class Config:
from_attributes = True

@ -0,0 +1,42 @@
"""
扫描相关的Pydantic模式
"""
from pydantic import BaseModel
from typing import Optional
from datetime import datetime
from app.models.scan import ScanStatus, ScanType
class ScanBase(BaseModel):
"""扫描基础模式"""
project_id: int
scan_type: ScanType = ScanType.FULL
scan_config: Optional[str] = None
class ScanCreate(ScanBase):
"""创建扫描模式"""
pass
class ScanResponse(ScanBase):
"""扫描响应模式"""
id: int
status: ScanStatus
total_files: int
scanned_files: int
total_vulnerabilities: int
started_at: Optional[datetime] = None
completed_at: Optional[datetime] = None
created_at: datetime
result_summary: Optional[str] = None
error_message: Optional[str] = None
class Config:
from_attributes = True
class ScanStatusResponse(BaseModel):
"""扫描状态响应模式"""
scan_id: int
status: str
progress: dict
started_at: Optional[datetime] = None
completed_at: Optional[datetime] = None
error_message: Optional[str] = None

@ -0,0 +1,49 @@
"""
漏洞相关的Pydantic模式
"""
from pydantic import BaseModel
from typing import Optional
from datetime import datetime
from app.models.vulnerability import SeverityLevel, VulnerabilityCategory, VulnerabilityStatus
class VulnerabilityBase(BaseModel):
"""漏洞基础模式"""
rule_id: str
message: str
category: VulnerabilityCategory
severity: SeverityLevel
file_path: str
line_number: Optional[int] = None
column_number: Optional[int] = None
end_line: Optional[int] = None
end_column: Optional[int] = None
code_snippet: Optional[str] = None
context_before: Optional[str] = None
context_after: Optional[str] = None
ai_enhanced: bool = False
ai_confidence: Optional[float] = None
ai_suggestion: Optional[str] = None
class VulnerabilityCreate(VulnerabilityBase):
"""创建漏洞模式"""
scan_id: int
class VulnerabilityUpdate(BaseModel):
"""更新漏洞模式"""
status: Optional[VulnerabilityStatus] = None
assigned_to: Optional[str] = None
fix_commit: Optional[str] = None
class VulnerabilityResponse(VulnerabilityBase):
"""漏洞响应模式"""
id: int
scan_id: int
status: VulnerabilityStatus
assigned_to: Optional[str] = None
fix_commit: Optional[str] = None
created_at: datetime
updated_at: Optional[datetime] = None
fixed_at: Optional[datetime] = None
class Config:
from_attributes = True

@ -0,0 +1,137 @@
"""
AI增强服务 - 基于现有的DeepSeek集成
"""
import requests
import json
import time
from typing import Dict, Any, List
class AIService:
"""AI增强服务"""
def __init__(self):
# 从环境变量或配置文件读取API配置
self.api_url = "https://api.deepseek.com/v1/chat/completions"
self.api_key = "your_deepseek_api_key_here" # 实际使用时从环境变量获取
self.headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json"
}
async def enhance_vulnerability(self, vulnerability: Dict[str, Any]) -> Dict[str, Any]:
"""AI增强漏洞分析"""
try:
# 构建AI分析提示
prompt = self._build_enhancement_prompt(vulnerability)
# 调用AI API
ai_response = await self._call_ai_api(prompt)
# 解析AI响应
enhancement = self._parse_ai_response(ai_response)
return {
'ai_enhanced': True,
'ai_confidence': enhancement.get('confidence', 0.8),
'ai_suggestion': enhancement.get('suggestion', ''),
'ai_explanation': enhancement.get('explanation', '')
}
except Exception as e:
print(f"AI增强失败: {str(e)}")
return {
'ai_enhanced': False,
'ai_confidence': 0.0,
'ai_suggestion': '',
'ai_explanation': f'AI分析失败: {str(e)}'
}
def _build_enhancement_prompt(self, vulnerability: Dict[str, Any]) -> str:
"""构建AI分析提示"""
prompt = f"""
请分析以下代码漏洞并提供详细的修复建议
漏洞信息
- 规则ID: {vulnerability.get('rule_id', 'N/A')}
- 严重程度: {vulnerability.get('severity', 'N/A')}
- 分类: {vulnerability.get('category', 'N/A')}
- 文件路径: {vulnerability.get('file_path', 'N/A')}
- 行号: {vulnerability.get('line_number', 'N/A')}
- 描述: {vulnerability.get('message', 'N/A')}
相关代码
```{vulnerability.get('language', 'text')}
{vulnerability.get('code_snippet', '')}
```
请提供
1. 漏洞的详细解释
2. 可能的修复方案
3. 修复后的代码示例
4. 预防类似问题的最佳实践
请以JSON格式回复包含以下字段
- explanation: 详细解释
- suggestion: 修复建议
- fixed_code: 修复后的代码示例
- best_practices: 最佳实践建议
- confidence: 分析置信度(0-1)
"""
return prompt
async def _call_ai_api(self, prompt: str) -> str:
"""调用AI API"""
data = {
"model": "deepseek-chat",
"messages": [
{"role": "system", "content": "你是一个专业的代码安全分析专家。"},
{"role": "user", "content": prompt}
],
"temperature": 0.3,
"max_tokens": 2000
}
response = requests.post(self.api_url, headers=self.headers, json=data)
response.raise_for_status()
result = response.json()
return result['choices'][0]['message']['content']
def _parse_ai_response(self, response: str) -> Dict[str, Any]:
"""解析AI响应"""
try:
# 尝试解析JSON响应
if response.strip().startswith('{'):
return json.loads(response)
# 如果不是JSON返回原始响应
return {
'explanation': response,
'suggestion': '',
'fixed_code': '',
'best_practices': '',
'confidence': 0.7
}
except json.JSONDecodeError:
return {
'explanation': response,
'suggestion': '',
'fixed_code': '',
'best_practices': '',
'confidence': 0.7
}
async def batch_enhance_vulnerabilities(self, vulnerabilities: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""批量AI增强漏洞"""
enhanced_vulnerabilities = []
for vulnerability in vulnerabilities:
enhancement = await self.enhance_vulnerability(vulnerability)
vulnerability.update(enhancement)
enhanced_vulnerabilities.append(vulnerability)
# 避免API请求过快
time.sleep(0.5)
return enhanced_vulnerabilities

@ -0,0 +1,60 @@
"""
代码分析服务
"""
import os
import json
import subprocess
from typing import List, Dict, Any
from app.core.analyzers.base_analyzer import BaseAnalyzer
from app.core.analyzers.python_analyzer import PythonAnalyzer
from app.core.analyzers.cpp_analyzer import CppAnalyzer
from app.core.analyzers.javascript_analyzer import JavaScriptAnalyzer
from app.services.ai_service import AIService
class AnalyzerService:
"""代码分析服务"""
def __init__(self):
self.analyzers = {
'python': PythonAnalyzer(),
'cpp': CppAnalyzer(),
'javascript': JavaScriptAnalyzer(),
}
self.ai_service = AIService()
async def analyze_project(self, project_path: str, language: str, config: Dict[str, Any] = None) -> List[Dict[str, Any]]:
"""分析项目代码"""
if language not in self.analyzers:
raise ValueError(f"不支持的语言: {language}")
analyzer = self.analyzers[language]
# 运行静态分析
static_results = await analyzer.analyze(project_path, config)
# AI增强分析
ai_enhanced_results = []
for result in static_results:
# 对每个漏洞进行AI增强
ai_enhancement = await self.ai_service.enhance_vulnerability(result)
result.update(ai_enhancement)
ai_enhanced_results.append(result)
return ai_enhanced_results
def get_supported_languages(self) -> List[str]:
"""获取支持的语言列表"""
return list(self.analyzers.keys())
def get_analyzer_info(self, language: str) -> Dict[str, Any]:
"""获取分析器信息"""
if language not in self.analyzers:
return None
analyzer = self.analyzers[language]
return {
'name': analyzer.name,
'version': analyzer.version,
'supported_extensions': analyzer.supported_extensions,
'description': analyzer.description
}

@ -0,0 +1,120 @@
"""
仪表板服务
"""
from sqlalchemy.orm import Session
from sqlalchemy import func, desc
from app.models.project import Project
from app.models.scan import Scan
from app.models.vulnerability import Vulnerability, VulnerabilityStatus, SeverityLevel
class DashboardService:
"""仪表板服务"""
async def get_summary_data(self, db: Session) -> dict:
"""获取仪表板汇总数据"""
# 项目统计
total_projects = db.query(Project).filter(Project.is_active == True).count()
# 扫描统计
total_scans = db.query(Scan).count()
completed_scans = db.query(Scan).filter(Scan.status == "completed").count()
# 漏洞统计
total_vulnerabilities = db.query(Vulnerability).count()
fixed_vulnerabilities = db.query(Vulnerability).filter(
Vulnerability.status == VulnerabilityStatus.FIXED
).count()
# 按严重程度统计
severity_stats = db.query(
Vulnerability.severity,
func.count(Vulnerability.id).label('count')
).group_by(Vulnerability.severity).all()
severity_summary = {}
for severity, count in severity_stats:
severity_summary[severity.value] = count
# 最近发现的漏洞
recent_vulnerabilities = db.query(Vulnerability).order_by(
desc(Vulnerability.created_at)
).limit(10).all()
recent_vuln_data = []
for vuln in recent_vulnerabilities:
recent_vuln_data.append({
'id': vuln.id,
'project_name': vuln.scan.project.name if vuln.scan and vuln.scan.project else 'Unknown',
'category': vuln.category.value,
'severity': vuln.severity.value,
'file_path': vuln.file_path,
'message': vuln.message,
'created_at': vuln.created_at.isoformat()
})
# 项目漏洞统计
project_stats = db.query(
Project.name,
func.count(Vulnerability.id).label('vulnerability_count')
).join(Scan, Project.id == Scan.project_id).join(
Vulnerability, Scan.id == Vulnerability.scan_id
).group_by(Project.id, Project.name).order_by(
desc('vulnerability_count')
).limit(5).all()
project_vuln_data = []
for project_name, count in project_stats:
project_vuln_data.append({
'project_name': project_name,
'vulnerability_count': count
})
return {
'projects': total_projects,
'scans': total_scans,
'completed_scans': completed_scans,
'vulnerabilities': total_vulnerabilities,
'fixed': fixed_vulnerabilities,
'severity_summary': severity_summary,
'recent_vulnerabilities': recent_vuln_data,
'project_stats': project_vuln_data
}
async def get_trend_data(self, db: Session, days: int = 30) -> dict:
"""获取趋势数据"""
from datetime import datetime, timedelta
end_date = datetime.now()
start_date = end_date - timedelta(days=days)
# 每日扫描统计
daily_scans = db.query(
func.date(Scan.created_at).label('date'),
func.count(Scan.id).label('count')
).filter(
Scan.created_at >= start_date
).group_by(func.date(Scan.created_at)).all()
# 每日漏洞统计
daily_vulnerabilities = db.query(
func.date(Vulnerability.created_at).label('date'),
func.count(Vulnerability.id).label('count')
).filter(
Vulnerability.created_at >= start_date
).group_by(func.date(Vulnerability.created_at)).all()
return {
'daily_scans': [{'date': str(date), 'count': count} for date, count in daily_scans],
'daily_vulnerabilities': [{'date': str(date), 'count': count} for date, count in daily_vulnerabilities]
}
async def get_category_distribution(self, db: Session) -> dict:
"""获取漏洞分类分布"""
category_stats = db.query(
Vulnerability.category,
func.count(Vulnerability.id).label('count')
).group_by(Vulnerability.category).all()
return {
category.value: count for category, count in category_stats
}

@ -0,0 +1,220 @@
"""
报告生成服务
"""
import os
import json
import pandas as pd
from jinja2 import Template
from typing import Dict, Any
from datetime import datetime
from app.models.scan import Scan
from app.models.project import Project
class ReportService:
"""报告生成服务"""
def __init__(self):
self.templates_dir = "app/templates"
self.reports_dir = "reports"
os.makedirs(self.reports_dir, exist_ok=True)
async def generate_html_report(self, scan: Scan) -> str:
"""生成HTML报告"""
# 准备报告数据
report_data = await self._prepare_report_data(scan)
# 读取HTML模板
template_path = os.path.join(self.templates_dir, "scan_report.html")
with open(template_path, 'r', encoding='utf-8') as f:
template_content = f.read()
# 渲染模板
template = Template(template_content)
html_content = template.render(**report_data)
# 保存HTML文件
report_filename = f"scan_report_{scan.id}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.html"
report_path = os.path.join(self.reports_dir, report_filename)
with open(report_path, 'w', encoding='utf-8') as f:
f.write(html_content)
return report_path
async def generate_pdf_report(self, scan: Scan) -> str:
"""生成PDF报告"""
# 先生成HTML报告
html_path = await self.generate_html_report(scan)
# 按需导入WeasyPrint并给出友好降级
try:
from weasyprint import HTML # type: ignore
except Exception as exc: # ImportError 或底层依赖缺失
raise RuntimeError(
"PDF 导出所需依赖缺失WeasyPrint 及其系统库)。" \
"请先使用 HTML/Excel/JSON 导出,或按安装指南配置 WeasyPrint。原始错误: " + str(exc)
)
# 转换为PDF
pdf_filename = f"scan_report_{scan.id}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.pdf"
pdf_path = os.path.join(self.reports_dir, pdf_filename)
HTML(filename=html_path).write_pdf(pdf_path)
return pdf_path
async def generate_json_report(self, scan: Scan) -> Dict[str, Any]:
"""生成JSON报告"""
report_data = await self._prepare_report_data(scan)
return report_data
async def generate_excel_report(self, scan: Scan) -> str:
"""生成Excel报告"""
# 获取漏洞数据
vulnerabilities = scan.vulnerabilities
# 准备Excel数据
excel_data = []
for vuln in vulnerabilities:
excel_data.append({
'ID': vuln.id,
'规则ID': vuln.rule_id,
'严重程度': vuln.severity.value,
'分类': vuln.category.value,
'文件路径': vuln.file_path,
'行号': vuln.line_number,
'描述': vuln.message,
'AI增强': '' if vuln.ai_enhanced else '',
'AI置信度': vuln.ai_confidence,
'AI建议': vuln.ai_suggestion,
'状态': vuln.status.value,
'创建时间': vuln.created_at.strftime('%Y-%m-%d %H:%M:%S')
})
# 创建DataFrame
df = pd.DataFrame(excel_data)
# 保存Excel文件
excel_filename = f"scan_report_{scan.id}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.xlsx"
excel_path = os.path.join(self.reports_dir, excel_filename)
with pd.ExcelWriter(excel_path, engine='openpyxl') as writer:
df.to_excel(writer, sheet_name='漏洞详情', index=False)
# 添加统计信息表
stats_data = await self._generate_stats_data(scan)
stats_df = pd.DataFrame(stats_data)
stats_df.to_excel(writer, sheet_name='统计信息', index=False)
return excel_path
async def generate_project_html_report(self, project: Project, latest_scan: Scan) -> str:
"""生成项目汇总报告"""
# 准备项目报告数据
report_data = await self._prepare_project_report_data(project, latest_scan)
# 读取项目报告模板
template_path = os.path.join(self.templates_dir, "project_report.html")
with open(template_path, 'r', encoding='utf-8') as f:
template_content = f.read()
# 渲染模板
template = Template(template_content)
html_content = template.render(**report_data)
# 保存HTML文件
report_filename = f"project_report_{project.id}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.html"
report_path = os.path.join(self.reports_dir, report_filename)
with open(report_path, 'w', encoding='utf-8') as f:
f.write(html_content)
return report_path
async def generate_project_json_report(self, project: Project, latest_scan: Scan) -> Dict[str, Any]:
"""生成项目JSON报告"""
report_data = await self._prepare_project_report_data(project, latest_scan)
return report_data
async def _prepare_report_data(self, scan: Scan) -> Dict[str, Any]:
"""准备报告数据"""
vulnerabilities = scan.vulnerabilities
# 按严重程度分组
by_severity = {}
by_category = {}
for vuln in vulnerabilities:
severity = vuln.severity.value
category = vuln.category.value
if severity not in by_severity:
by_severity[severity] = []
by_severity[severity].append(vuln)
if category not in by_category:
by_category[category] = []
by_category[category].append(vuln)
return {
'scan': scan,
'project': scan.project,
'vulnerabilities': vulnerabilities,
'by_severity': by_severity,
'by_category': by_category,
'total_vulnerabilities': len(vulnerabilities),
'generated_at': datetime.now().strftime('%Y-%m-%d %H:%M:%S')
}
async def _prepare_project_report_data(self, project: Project, latest_scan: Scan) -> Dict[str, Any]:
"""准备项目报告数据"""
vulnerabilities = latest_scan.vulnerabilities
# 统计信息
total_vulnerabilities = len(vulnerabilities)
critical_count = len([v for v in vulnerabilities if v.severity.value == 'critical'])
high_count = len([v for v in vulnerabilities if v.severity.value == 'high'])
medium_count = len([v for v in vulnerabilities if v.severity.value == 'medium'])
low_count = len([v for v in vulnerabilities if v.severity.value == 'low'])
return {
'project': project,
'latest_scan': latest_scan,
'vulnerabilities': vulnerabilities,
'stats': {
'total': total_vulnerabilities,
'critical': critical_count,
'high': high_count,
'medium': medium_count,
'low': low_count
},
'generated_at': datetime.now().strftime('%Y-%m-%d %H:%M:%S')
}
async def _generate_stats_data(self, scan: Scan) -> list:
"""生成统计信息数据"""
vulnerabilities = scan.vulnerabilities
# 按严重程度统计
severity_stats = {}
for vuln in vulnerabilities:
severity = vuln.severity.value
severity_stats[severity] = severity_stats.get(severity, 0) + 1
# 按分类统计
category_stats = {}
for vuln in vulnerabilities:
category = vuln.category.value
category_stats[category] = category_stats.get(category, 0) + 1
stats_data = []
stats_data.append(['统计类型', '分类', '数量'])
stats_data.append(['严重程度', '总计', len(vulnerabilities)])
for severity, count in severity_stats.items():
stats_data.append(['严重程度', severity, count])
for category, count in category_stats.items():
stats_data.append(['分类', category, count])
return stats_data

@ -0,0 +1,162 @@
"""
扫描服务
"""
import asyncio
from typing import List
from sqlalchemy.orm import Session
from app.database import SessionLocal
from app.models.scan import Scan, ScanStatus
from app.models.vulnerability import Vulnerability, VulnerabilityStatus, SeverityLevel, VulnerabilityCategory
from app.services.analyzer_service import AnalyzerService
from app.services.report_service import ReportService
import json
class ScanService:
"""扫描服务"""
def __init__(self):
self.analyzer_service = AnalyzerService()
self.report_service = ReportService()
async def run_scan(self, scan_id: int):
"""运行扫描任务"""
db = SessionLocal()
try:
# 获取扫描任务
scan = db.query(Scan).filter(Scan.id == scan_id).first()
if not scan:
print(f"扫描任务不存在: {scan_id}")
return
# 更新扫描状态为运行中
scan.status = ScanStatus.RUNNING
scan.started_at = asyncio.get_event_loop().time()
db.commit()
try:
# 获取项目信息
project = scan.project
if not project:
raise Exception("项目不存在")
# 解析扫描配置
scan_config = {}
if scan.scan_config:
scan_config = json.loads(scan.scan_config)
# 执行代码分析
vulnerabilities_data = await self.analyzer_service.analyze_project(
project_path=project.project_path,
language=project.language,
config=scan_config
)
# 保存漏洞数据到数据库
await self._save_vulnerabilities(db, scan, vulnerabilities_data)
# 更新扫描统计
scan.total_files = scan_config.get('total_files', 100) # 模拟文件数
scan.scanned_files = scan.total_files
scan.total_vulnerabilities = len(vulnerabilities_data)
scan.status = ScanStatus.COMPLETED
scan.completed_at = asyncio.get_event_loop().time()
# 生成结果摘要
summary = {
'total_vulnerabilities': scan.total_vulnerabilities,
'by_severity': {},
'by_category': {}
}
for vuln_data in vulnerabilities_data:
severity = vuln_data.get('severity', 'medium')
category = vuln_data.get('category', 'maintainability')
summary['by_severity'][severity] = summary['by_severity'].get(severity, 0) + 1
summary['by_category'][category] = summary['by_category'].get(category, 0) + 1
scan.result_summary = json.dumps(summary)
db.commit()
print(f"扫描完成: {scan_id}, 发现 {scan.total_vulnerabilities} 个漏洞")
except Exception as e:
# 扫描失败
scan.status = ScanStatus.FAILED
scan.error_message = str(e)
db.commit()
print(f"扫描失败: {scan_id}, 错误: {str(e)}")
finally:
db.close()
async def _save_vulnerabilities(self, db: Session, scan: Scan, vulnerabilities_data: List[dict]):
"""保存漏洞数据到数据库"""
for vuln_data in vulnerabilities_data:
# 映射严重程度
severity_mapping = {
'critical': SeverityLevel.CRITICAL,
'high': SeverityLevel.HIGH,
'medium': SeverityLevel.MEDIUM,
'low': SeverityLevel.LOW,
'info': SeverityLevel.INFO
}
# 映射分类
category_mapping = {
'security': VulnerabilityCategory.SECURITY,
'performance': VulnerabilityCategory.PERFORMANCE,
'maintainability': VulnerabilityCategory.MAINTAINABILITY,
'reliability': VulnerabilityCategory.RELIABILITY,
'usability': VulnerabilityCategory.USABILITY
}
vulnerability = Vulnerability(
scan_id=scan.id,
rule_id=vuln_data.get('rule_id', 'unknown'),
message=vuln_data.get('message', ''),
category=category_mapping.get(vuln_data.get('category', 'maintainability'), VulnerabilityCategory.MAINTAINABILITY),
severity=severity_mapping.get(vuln_data.get('severity', 'medium'), SeverityLevel.MEDIUM),
file_path=vuln_data.get('file_path', ''),
line_number=vuln_data.get('line_number'),
column_number=vuln_data.get('column_number'),
end_line=vuln_data.get('end_line'),
end_column=vuln_data.get('end_column'),
code_snippet=vuln_data.get('code_snippet', ''),
context_before=vuln_data.get('context_before', ''),
context_after=vuln_data.get('context_after', ''),
ai_enhanced=vuln_data.get('ai_enhanced', False),
ai_confidence=vuln_data.get('ai_confidence'),
ai_suggestion=vuln_data.get('ai_suggestion', ''),
status=VulnerabilityStatus.OPEN
)
db.add(vulnerability)
db.commit()
def get_scan_progress(self, scan_id: int) -> dict:
"""获取扫描进度"""
db = SessionLocal()
try:
scan = db.query(Scan).filter(Scan.id == scan_id).first()
if not scan:
return {'error': '扫描任务不存在'}
progress = 0
if scan.total_files > 0:
progress = (scan.scanned_files / scan.total_files) * 100
return {
'scan_id': scan_id,
'status': scan.status.value,
'progress': progress,
'total_files': scan.total_files,
'scanned_files': scan.scanned_files,
'total_vulnerabilities': scan.total_vulnerabilities,
'started_at': scan.started_at,
'completed_at': scan.completed_at,
'error_message': scan.error_message
}
finally:
db.close()

@ -0,0 +1,247 @@
<!DOCTYPE html>
<html lang="zh-CN">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>代码扫描报告 - {{ project.name }}</title>
<style>
body {
font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
line-height: 1.6;
margin: 0;
padding: 20px;
background-color: #f5f5f5;
}
.container {
max-width: 1200px;
margin: 0 auto;
background: white;
padding: 30px;
border-radius: 8px;
box-shadow: 0 2px 10px rgba(0,0,0,0.1);
}
.header {
text-align: center;
border-bottom: 2px solid #1890ff;
padding-bottom: 20px;
margin-bottom: 30px;
}
.header h1 {
color: #1890ff;
margin: 0;
font-size: 2.5em;
}
.header p {
color: #666;
margin: 10px 0 0 0;
}
.summary {
display: grid;
grid-template-columns: repeat(auto-fit, minmax(200px, 1fr));
gap: 20px;
margin-bottom: 30px;
}
.summary-card {
background: #f8f9fa;
padding: 20px;
border-radius: 8px;
text-align: center;
border-left: 4px solid #1890ff;
}
.summary-card h3 {
margin: 0 0 10px 0;
color: #333;
}
.summary-card .number {
font-size: 2em;
font-weight: bold;
color: #1890ff;
}
.severity-critical { border-left-color: #ff4d4f; }
.severity-high { border-left-color: #ff7a45; }
.severity-medium { border-left-color: #ffa940; }
.severity-low { border-left-color: #73d13d; }
.severity-info { border-left-color: #40a9ff; }
.section {
margin-bottom: 30px;
}
.section h2 {
color: #333;
border-bottom: 2px solid #f0f0f0;
padding-bottom: 10px;
}
.vulnerability {
background: #fff;
border: 1px solid #e8e8e8;
border-radius: 8px;
margin-bottom: 15px;
overflow: hidden;
}
.vulnerability-header {
background: #f8f9fa;
padding: 15px 20px;
border-bottom: 1px solid #e8e8e8;
display: flex;
justify-content: space-between;
align-items: center;
}
.vulnerability-title {
font-weight: bold;
font-size: 1.1em;
}
.severity-badge {
padding: 4px 12px;
border-radius: 20px;
color: white;
font-size: 0.9em;
font-weight: bold;
}
.severity-critical { background: #ff4d4f; }
.severity-high { background: #ff7a45; }
.severity-medium { background: #ffa940; }
.severity-low { background: #73d13d; }
.severity-info { background: #40a9ff; }
.vulnerability-body {
padding: 20px;
}
.vulnerability-meta {
display: grid;
grid-template-columns: 1fr 1fr;
gap: 20px;
margin-bottom: 15px;
}
.meta-item {
display: flex;
align-items: center;
}
.meta-label {
font-weight: bold;
margin-right: 10px;
min-width: 80px;
}
.file-path {
font-family: 'Courier New', monospace;
background: #f5f5f5;
padding: 2px 6px;
border-radius: 4px;
}
.code-block {
background: #f8f8f8;
border: 1px solid #e8e8e8;
border-radius: 4px;
padding: 15px;
margin: 10px 0;
font-family: 'Courier New', monospace;
font-size: 0.9em;
overflow-x: auto;
}
.ai-suggestion {
background: #e6f7ff;
border: 1px solid #91d5ff;
border-radius: 4px;
padding: 15px;
margin-top: 10px;
}
.ai-suggestion h4 {
margin: 0 0 10px 0;
color: #1890ff;
}
.footer {
text-align: center;
margin-top: 40px;
padding-top: 20px;
border-top: 1px solid #e8e8e8;
color: #666;
}
@media print {
body { background: white; }
.container { box-shadow: none; }
}
</style>
</head>
<body>
<div class="container">
<div class="header">
<h1>代码扫描报告</h1>
<p>项目: {{ project.name }} | 生成时间: {{ generated_at }}</p>
</div>
<!-- 扫描摘要 -->
<div class="section">
<h2>扫描摘要</h2>
<div class="summary">
<div class="summary-card">
<h3>总漏洞数</h3>
<div class="number">{{ total_vulnerabilities }}</div>
</div>
{% for severity, vulns in by_severity.items() %}
<div class="summary-card severity-{{ severity }}">
<h3>{{ severity|title }} 漏洞</h3>
<div class="number">{{ vulns|length }}</div>
</div>
{% endfor %}
</div>
</div>
<!-- 漏洞详情 -->
<div class="section">
<h2>漏洞详情</h2>
{% for vulnerability in vulnerabilities %}
<div class="vulnerability">
<div class="vulnerability-header">
<div class="vulnerability-title">
{{ vulnerability.rule_id }}: {{ vulnerability.message }}
</div>
<span class="severity-badge severity-{{ vulnerability.severity.value }}">
{{ vulnerability.severity.value|upper }}
</span>
</div>
<div class="vulnerability-body">
<div class="vulnerability-meta">
<div class="meta-item">
<span class="meta-label">文件:</span>
<span class="file-path">{{ vulnerability.file_path }}</span>
</div>
<div class="meta-item">
<span class="meta-label">行号:</span>
<span>{{ vulnerability.line_number or 'N/A' }}</span>
</div>
<div class="meta-item">
<span class="meta-label">分类:</span>
<span>{{ vulnerability.category.value }}</span>
</div>
<div class="meta-item">
<span class="meta-label">状态:</span>
<span>{{ vulnerability.status.value }}</span>
</div>
</div>
{% if vulnerability.code_snippet %}
<div>
<strong>相关代码:</strong>
<div class="code-block">{{ vulnerability.code_snippet }}</div>
</div>
{% endif %}
{% if vulnerability.ai_enhanced and vulnerability.ai_suggestion %}
<div class="ai-suggestion">
<h4>🤖 AI 建议</h4>
<p>{{ vulnerability.ai_suggestion }}</p>
{% if vulnerability.ai_confidence %}
<small>置信度: {{ (vulnerability.ai_confidence * 100)|round(1) }}%</small>
{% endif %}
</div>
{% endif %}
</div>
</div>
{% endfor %}
</div>
<div class="footer">
<p>此报告由代码漏洞检测系统自动生成</p>
</div>
</div>
</body>
</html>

@ -0,0 +1,33 @@
"""
配置管理
"""
import os
from typing import Optional
class Config:
"""配置类"""
# 数据库配置
DATABASE_URL: str = os.getenv("DATABASE_URL", "sqlite:///./code_scanner.db")
# AI服务配置
DEEPSEEK_API_URL: str = os.getenv("DEEPSEEK_API_URL", "https://api.deepseek.com/v1/chat/completions")
DEEPSEEK_API_KEY: str = os.getenv("DEEPSEEK_API_KEY", "your_deepseek_api_key_here")
# 文件上传配置
UPLOAD_FOLDER: str = os.getenv("UPLOAD_FOLDER", "uploads")
MAX_CONTENT_LENGTH: int = int(os.getenv("MAX_CONTENT_LENGTH", "16 * 1024 * 1024")) # 16MB
# 扫描配置
MAX_SCAN_FILES: int = int(os.getenv("MAX_SCAN_FILES", "1000"))
SCAN_TIMEOUT: int = int(os.getenv("SCAN_TIMEOUT", "300")) # 5分钟
# 报告配置
REPORTS_FOLDER: str = os.getenv("REPORTS_FOLDER", "reports")
# 日志配置
LOG_LEVEL: str = os.getenv("LOG_LEVEL", "INFO")
LOG_FILE: str = os.getenv("LOG_FILE", "app.log")
# 创建配置实例
config = Config()

@ -0,0 +1,55 @@
"""
代码漏洞检测系统 - 后端主启动文件
"""
import uvicorn
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from app.api import projects, scans, reports, vulnerabilities, files
from app.database import init_db
# 创建FastAPI应用
app = FastAPI(
title="代码漏洞检测系统",
description="基于AI增强的代码漏洞检测和报告生成系统",
version="1.0.0"
)
# 配置CORS
app.add_middleware(
CORSMiddleware,
allow_origins=["http://localhost:3000"], # 前端地址
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# 注册路由
app.include_router(projects.router, prefix="/api/projects", tags=["projects"])
app.include_router(scans.router, prefix="/api/scans", tags=["scans"])
app.include_router(reports.router, prefix="/api/reports", tags=["reports"])
app.include_router(vulnerabilities.router, prefix="/api/vulnerabilities", tags=["vulnerabilities"])
app.include_router(files.router, prefix="/api", tags=["files"])
@app.on_event("startup")
async def startup_event():
"""应用启动时初始化数据库"""
init_db()
@app.get("/")
async def root():
"""根路径健康检查"""
return {"message": "代码漏洞检测系统 API", "status": "running"}
@app.get("/health")
async def health_check():
"""健康检查接口"""
return {"status": "healthy"}
if __name__ == "__main__":
uvicorn.run(
"main:app",
host="0.0.0.0",
port=8000,
reload=True,
log_level="info"
)

@ -0,0 +1,12 @@
fastapi==0.104.1
uvicorn[standard]==0.24.0
sqlalchemy==2.0.23
pydantic==2.5.0
python-multipart==0.0.6
requests==2.31.0
python-dotenv==1.0.0
alembic==1.12.1
pandas==2.1.4
jinja2==3.1.2
weasyprint==60.2
openpyxl==3.1.2
Loading…
Cancel
Save