parent
d0c0d05107
commit
56a233bf72
@ -0,0 +1 @@
|
||||
# 应用包初始化文件
|
||||
@ -0,0 +1 @@
|
||||
# API路由包
|
||||
@ -0,0 +1,245 @@
|
||||
"""
|
||||
文件管理API路由
|
||||
"""
|
||||
import os
|
||||
import mimetypes
|
||||
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File
|
||||
from fastapi.responses import FileResponse, StreamingResponse
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import List
|
||||
from app.database import get_db
|
||||
from app.models.project import Project
|
||||
import zipfile
|
||||
import tempfile
|
||||
import shutil
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@router.get("/projects/{project_id}/files")
|
||||
async def get_project_files(
|
||||
project_id: int,
|
||||
path: str = "",
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""获取项目文件列表"""
|
||||
project = db.query(Project).filter(Project.id == project_id).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="项目不存在")
|
||||
|
||||
if not project.project_path or not os.path.exists(project.project_path):
|
||||
raise HTTPException(status_code=404, detail="项目路径不存在")
|
||||
|
||||
target_path = os.path.join(project.project_path, path) if path else project.project_path
|
||||
|
||||
if not os.path.exists(target_path):
|
||||
raise HTTPException(status_code=404, detail="路径不存在")
|
||||
|
||||
if not os.path.isdir(target_path):
|
||||
raise HTTPException(status_code=400, detail="路径不是目录")
|
||||
|
||||
files = []
|
||||
directories = []
|
||||
|
||||
try:
|
||||
for item in os.listdir(target_path):
|
||||
item_path = os.path.join(target_path, item)
|
||||
|
||||
# 跳过隐藏文件和常见的非代码文件
|
||||
if item.startswith('.') or item in ['node_modules', '__pycache__', '.git', 'venv', 'env']:
|
||||
continue
|
||||
|
||||
item_info = {
|
||||
'name': item,
|
||||
'path': os.path.join(path, item).replace('\\', '/'),
|
||||
'is_directory': os.path.isdir(item_path),
|
||||
'size': os.path.getsize(item_path) if os.path.isfile(item_path) else 0,
|
||||
'modified': os.path.getmtime(item_path)
|
||||
}
|
||||
|
||||
if item_info['is_directory']:
|
||||
directories.append(item_info)
|
||||
else:
|
||||
# 只显示代码文件
|
||||
if _is_code_file(item):
|
||||
files.append(item_info)
|
||||
|
||||
# 排序:目录在前,文件在后,按名称排序
|
||||
directories.sort(key=lambda x: x['name'].lower())
|
||||
files.sort(key=lambda x: x['name'].lower())
|
||||
|
||||
return {
|
||||
'files': directories + files,
|
||||
'current_path': path,
|
||||
'project_path': project.project_path
|
||||
}
|
||||
|
||||
except PermissionError:
|
||||
raise HTTPException(status_code=403, detail="权限不足")
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"读取文件列表失败: {str(e)}")
|
||||
|
||||
@router.get("/projects/{project_id}/files/content")
|
||||
async def get_file_content(
|
||||
project_id: int,
|
||||
file_path: str,
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""获取文件内容"""
|
||||
project = db.query(Project).filter(Project.id == project_id).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="项目不存在")
|
||||
|
||||
full_path = os.path.join(project.project_path, file_path)
|
||||
|
||||
if not os.path.exists(full_path) or not os.path.isfile(full_path):
|
||||
raise HTTPException(status_code=404, detail="文件不存在")
|
||||
|
||||
if not _is_code_file(full_path):
|
||||
raise HTTPException(status_code=400, detail="不支持的文件类型")
|
||||
|
||||
try:
|
||||
# 尝试不同的编码
|
||||
encodings = ['utf-8', 'gbk', 'gb2312', 'latin-1']
|
||||
content = None
|
||||
|
||||
for encoding in encodings:
|
||||
try:
|
||||
with open(full_path, 'r', encoding=encoding) as f:
|
||||
content = f.read()
|
||||
break
|
||||
except UnicodeDecodeError:
|
||||
continue
|
||||
|
||||
if content is None:
|
||||
raise HTTPException(status_code=400, detail="文件编码不支持")
|
||||
|
||||
return {
|
||||
'content': content,
|
||||
'file_path': file_path,
|
||||
'size': len(content.encode('utf-8')),
|
||||
'lines': len(content.splitlines())
|
||||
}
|
||||
|
||||
except PermissionError:
|
||||
raise HTTPException(status_code=403, detail="权限不足")
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"读取文件失败: {str(e)}")
|
||||
|
||||
@router.post("/projects/{project_id}/files/content")
|
||||
async def save_file_content(
|
||||
project_id: int,
|
||||
file_path: str,
|
||||
content: str,
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""保存文件内容"""
|
||||
project = db.query(Project).filter(Project.id == project_id).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="项目不存在")
|
||||
|
||||
full_path = os.path.join(project.project_path, file_path)
|
||||
|
||||
if not os.path.exists(full_path) or not os.path.isfile(full_path):
|
||||
raise HTTPException(status_code=404, detail="文件不存在")
|
||||
|
||||
try:
|
||||
# 备份原文件
|
||||
backup_path = f"{full_path}.backup"
|
||||
shutil.copy2(full_path, backup_path)
|
||||
|
||||
# 写入新内容
|
||||
with open(full_path, 'w', encoding='utf-8') as f:
|
||||
f.write(content)
|
||||
|
||||
return {"message": "文件保存成功", "backup": backup_path}
|
||||
|
||||
except PermissionError:
|
||||
raise HTTPException(status_code=403, detail="权限不足")
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"保存文件失败: {str(e)}")
|
||||
|
||||
@router.post("/projects/{project_id}/files/upload")
|
||||
async def upload_files(
|
||||
project_id: int,
|
||||
files: List[UploadFile] = File(...),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""上传文件到项目"""
|
||||
project = db.query(Project).filter(Project.id == project_id).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="项目不存在")
|
||||
|
||||
uploaded_files = []
|
||||
|
||||
try:
|
||||
for file in files:
|
||||
if not _is_code_file(file.filename):
|
||||
continue
|
||||
|
||||
file_path = os.path.join(project.project_path, file.filename)
|
||||
|
||||
with open(file_path, "wb") as buffer:
|
||||
content = await file.read()
|
||||
buffer.write(content)
|
||||
|
||||
uploaded_files.append({
|
||||
'filename': file.filename,
|
||||
'size': len(content)
|
||||
})
|
||||
|
||||
return {"message": f"成功上传 {len(uploaded_files)} 个文件", "files": uploaded_files}
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"上传文件失败: {str(e)}")
|
||||
|
||||
@router.get("/projects/{project_id}/files/download")
|
||||
async def download_project(
|
||||
project_id: int,
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""下载整个项目为ZIP文件"""
|
||||
project = db.query(Project).filter(Project.id == project_id).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="项目不存在")
|
||||
|
||||
if not os.path.exists(project.project_path):
|
||||
raise HTTPException(status_code=404, detail="项目路径不存在")
|
||||
|
||||
try:
|
||||
# 创建临时ZIP文件
|
||||
temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.zip')
|
||||
|
||||
with zipfile.ZipFile(temp_file.name, 'w', zipfile.ZIP_DEFLATED) as zipf:
|
||||
for root, dirs, files in os.walk(project.project_path):
|
||||
# 跳过不必要的目录
|
||||
dirs[:] = [d for d in dirs if not d.startswith('.') and d not in ['node_modules', '__pycache__']]
|
||||
|
||||
for file in files:
|
||||
if _is_code_file(file):
|
||||
file_path = os.path.join(root, file)
|
||||
arcname = os.path.relpath(file_path, project.project_path)
|
||||
zipf.write(file_path, arcname)
|
||||
|
||||
return FileResponse(
|
||||
temp_file.name,
|
||||
media_type='application/zip',
|
||||
filename=f"{project.name}.zip"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"创建下载文件失败: {str(e)}")
|
||||
|
||||
def _is_code_file(filename: str) -> bool:
|
||||
"""判断是否为代码文件"""
|
||||
code_extensions = {
|
||||
'.py', '.js', '.jsx', '.ts', '.tsx', '.java', '.cpp', '.c', '.h', '.hpp',
|
||||
'.cs', '.php', '.rb', '.go', '.rs', '.swift', '.kt', '.scala', '.sh',
|
||||
'.sql', '.html', '.css', '.scss', '.sass', '.less', '.vue', '.json',
|
||||
'.xml', '.yaml', '.yml', '.md', '.txt'
|
||||
}
|
||||
|
||||
if isinstance(filename, str):
|
||||
_, ext = os.path.splitext(filename.lower())
|
||||
return ext in code_extensions
|
||||
|
||||
return False
|
||||
@ -0,0 +1,77 @@
|
||||
"""
|
||||
项目管理API路由
|
||||
"""
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import List
|
||||
from app.database import get_db
|
||||
from app.models.project import Project
|
||||
from app.schemas.project import ProjectCreate, ProjectUpdate, ProjectResponse
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@router.get("/", response_model=List[ProjectResponse])
|
||||
async def get_projects(
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""获取项目列表"""
|
||||
projects = db.query(Project).filter(Project.is_active == True).offset(skip).limit(limit).all()
|
||||
return projects
|
||||
|
||||
@router.post("/", response_model=ProjectResponse, status_code=status.HTTP_201_CREATED)
|
||||
async def create_project(
|
||||
project: ProjectCreate,
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""创建新项目"""
|
||||
db_project = Project(**project.dict())
|
||||
db.add(db_project)
|
||||
db.commit()
|
||||
db.refresh(db_project)
|
||||
return db_project
|
||||
|
||||
@router.get("/{project_id}", response_model=ProjectResponse)
|
||||
async def get_project(
|
||||
project_id: int,
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""获取项目详情"""
|
||||
project = db.query(Project).filter(Project.id == project_id).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="项目不存在")
|
||||
return project
|
||||
|
||||
@router.put("/{project_id}", response_model=ProjectResponse)
|
||||
async def update_project(
|
||||
project_id: int,
|
||||
project_update: ProjectUpdate,
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""更新项目信息"""
|
||||
project = db.query(Project).filter(Project.id == project_id).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="项目不存在")
|
||||
|
||||
update_data = project_update.dict(exclude_unset=True)
|
||||
for field, value in update_data.items():
|
||||
setattr(project, field, value)
|
||||
|
||||
db.commit()
|
||||
db.refresh(project)
|
||||
return project
|
||||
|
||||
@router.delete("/{project_id}")
|
||||
async def delete_project(
|
||||
project_id: int,
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""删除项目(软删除)"""
|
||||
project = db.query(Project).filter(Project.id == project_id).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="项目不存在")
|
||||
|
||||
project.is_active = False
|
||||
db.commit()
|
||||
return {"message": "项目已删除"}
|
||||
@ -0,0 +1,100 @@
|
||||
"""
|
||||
报告生成API路由
|
||||
"""
|
||||
from fastapi import APIRouter, Depends, HTTPException, Response
|
||||
from fastapi.responses import FileResponse
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import Optional
|
||||
from app.database import get_db
|
||||
from app.models.scan import Scan
|
||||
from app.models.project import Project
|
||||
from app.services.report_service import ReportService
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@router.get("/scan/{scan_id}")
|
||||
async def generate_scan_report(
|
||||
scan_id: int,
|
||||
format: str = "html", # html, pdf, json, excel
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""生成扫描报告"""
|
||||
scan = db.query(Scan).filter(Scan.id == scan_id).first()
|
||||
if not scan:
|
||||
raise HTTPException(status_code=404, detail="扫描任务不存在")
|
||||
|
||||
if scan.status.value != "completed":
|
||||
raise HTTPException(status_code=400, detail="扫描未完成,无法生成报告")
|
||||
|
||||
report_service = ReportService()
|
||||
|
||||
if format == "html":
|
||||
report_path = await report_service.generate_html_report(scan)
|
||||
return FileResponse(
|
||||
report_path,
|
||||
media_type="text/html",
|
||||
filename=f"scan_report_{scan_id}.html"
|
||||
)
|
||||
elif format == "pdf":
|
||||
report_path = await report_service.generate_pdf_report(scan)
|
||||
return FileResponse(
|
||||
report_path,
|
||||
media_type="application/pdf",
|
||||
filename=f"scan_report_{scan_id}.pdf"
|
||||
)
|
||||
elif format == "json":
|
||||
report_data = await report_service.generate_json_report(scan)
|
||||
return report_data
|
||||
elif format == "excel":
|
||||
report_path = await report_service.generate_excel_report(scan)
|
||||
return FileResponse(
|
||||
report_path,
|
||||
media_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
|
||||
filename=f"scan_report_{scan_id}.xlsx"
|
||||
)
|
||||
else:
|
||||
raise HTTPException(status_code=400, detail="不支持的报告格式")
|
||||
|
||||
@router.get("/project/{project_id}")
|
||||
async def generate_project_report(
|
||||
project_id: int,
|
||||
format: str = "html",
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""生成项目汇总报告"""
|
||||
project = db.query(Project).filter(Project.id == project_id).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="项目不存在")
|
||||
|
||||
# 获取项目最新的扫描结果
|
||||
latest_scan = db.query(Scan).filter(
|
||||
Scan.project_id == project_id,
|
||||
Scan.status == "completed"
|
||||
).order_by(Scan.completed_at.desc()).first()
|
||||
|
||||
if not latest_scan:
|
||||
raise HTTPException(status_code=404, detail="项目没有完成的扫描记录")
|
||||
|
||||
report_service = ReportService()
|
||||
|
||||
if format == "html":
|
||||
report_path = await report_service.generate_project_html_report(project, latest_scan)
|
||||
return FileResponse(
|
||||
report_path,
|
||||
media_type="text/html",
|
||||
filename=f"project_report_{project_id}.html"
|
||||
)
|
||||
elif format == "json":
|
||||
report_data = await report_service.generate_project_json_report(project, latest_scan)
|
||||
return report_data
|
||||
else:
|
||||
raise HTTPException(status_code=400, detail="项目报告暂不支持此格式")
|
||||
|
||||
@router.get("/dashboard/summary")
|
||||
async def get_dashboard_summary(db: Session = Depends(get_db)):
|
||||
"""获取仪表板汇总数据"""
|
||||
from app.services.dashboard_service import DashboardService
|
||||
|
||||
dashboard_service = DashboardService()
|
||||
summary = await dashboard_service.get_summary_data(db)
|
||||
return summary
|
||||
@ -0,0 +1,108 @@
|
||||
"""
|
||||
扫描管理API路由
|
||||
"""
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, BackgroundTasks
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import List
|
||||
from app.database import get_db
|
||||
from app.models.scan import Scan, ScanStatus, ScanType
|
||||
from app.models.project import Project
|
||||
from app.schemas.scan import ScanCreate, ScanResponse, ScanStatusResponse
|
||||
from app.services.scan_service import ScanService
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@router.get("/", response_model=List[ScanResponse])
|
||||
async def get_scans(
|
||||
project_id: int = None,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""获取扫描历史"""
|
||||
query = db.query(Scan)
|
||||
if project_id:
|
||||
query = query.filter(Scan.project_id == project_id)
|
||||
|
||||
scans = query.order_by(Scan.created_at.desc()).offset(skip).limit(limit).all()
|
||||
return scans
|
||||
|
||||
@router.post("/", response_model=ScanResponse, status_code=status.HTTP_201_CREATED)
|
||||
async def create_scan(
|
||||
scan_data: ScanCreate,
|
||||
background_tasks: BackgroundTasks,
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""创建并启动扫描任务"""
|
||||
# 验证项目存在
|
||||
project = db.query(Project).filter(Project.id == scan_data.project_id).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="项目不存在")
|
||||
|
||||
# 创建扫描记录
|
||||
scan = Scan(
|
||||
project_id=scan_data.project_id,
|
||||
scan_type=scan_data.scan_type,
|
||||
scan_config=scan_data.scan_config,
|
||||
status=ScanStatus.PENDING
|
||||
)
|
||||
db.add(scan)
|
||||
db.commit()
|
||||
db.refresh(scan)
|
||||
|
||||
# 启动后台扫描任务
|
||||
background_tasks.add_task(ScanService.run_scan, scan.id)
|
||||
|
||||
return scan
|
||||
|
||||
@router.get("/{scan_id}", response_model=ScanResponse)
|
||||
async def get_scan(
|
||||
scan_id: int,
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""获取扫描详情"""
|
||||
scan = db.query(Scan).filter(Scan.id == scan_id).first()
|
||||
if not scan:
|
||||
raise HTTPException(status_code=404, detail="扫描任务不存在")
|
||||
return scan
|
||||
|
||||
@router.get("/{scan_id}/status", response_model=ScanStatusResponse)
|
||||
async def get_scan_status(
|
||||
scan_id: int,
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""获取扫描状态"""
|
||||
scan = db.query(Scan).filter(Scan.id == scan_id).first()
|
||||
if not scan:
|
||||
raise HTTPException(status_code=404, detail="扫描任务不存在")
|
||||
|
||||
return {
|
||||
"scan_id": scan.id,
|
||||
"status": scan.status.value,
|
||||
"progress": {
|
||||
"total_files": scan.total_files,
|
||||
"scanned_files": scan.scanned_files,
|
||||
"percentage": (scan.scanned_files / scan.total_files * 100) if scan.total_files > 0 else 0
|
||||
},
|
||||
"started_at": scan.started_at,
|
||||
"completed_at": scan.completed_at,
|
||||
"error_message": scan.error_message
|
||||
}
|
||||
|
||||
@router.post("/{scan_id}/cancel")
|
||||
async def cancel_scan(
|
||||
scan_id: int,
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""取消扫描任务"""
|
||||
scan = db.query(Scan).filter(Scan.id == scan_id).first()
|
||||
if not scan:
|
||||
raise HTTPException(status_code=404, detail="扫描任务不存在")
|
||||
|
||||
if scan.status in [ScanStatus.COMPLETED, ScanStatus.FAILED, ScanStatus.CANCELLED]:
|
||||
raise HTTPException(status_code=400, detail="扫描任务已完成或已取消")
|
||||
|
||||
scan.status = ScanStatus.CANCELLED
|
||||
db.commit()
|
||||
|
||||
return {"message": "扫描任务已取消"}
|
||||
@ -0,0 +1,118 @@
|
||||
"""
|
||||
漏洞管理API路由
|
||||
"""
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import List, Optional
|
||||
from app.database import get_db
|
||||
from app.models.vulnerability import Vulnerability, VulnerabilityStatus, SeverityLevel, VulnerabilityCategory
|
||||
from app.schemas.vulnerability import VulnerabilityResponse, VulnerabilityUpdate
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@router.get("/", response_model=List[VulnerabilityResponse])
|
||||
async def get_vulnerabilities(
|
||||
scan_id: Optional[int] = None,
|
||||
project_id: Optional[int] = None,
|
||||
severity: Optional[SeverityLevel] = None,
|
||||
category: Optional[VulnerabilityCategory] = None,
|
||||
status: Optional[VulnerabilityStatus] = None,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""获取漏洞列表"""
|
||||
query = db.query(Vulnerability)
|
||||
|
||||
# 应用过滤条件
|
||||
if scan_id:
|
||||
query = query.filter(Vulnerability.scan_id == scan_id)
|
||||
if severity:
|
||||
query = query.filter(Vulnerability.severity == severity)
|
||||
if category:
|
||||
query = query.filter(Vulnerability.category == category)
|
||||
if status:
|
||||
query = query.filter(Vulnerability.status == status)
|
||||
|
||||
vulnerabilities = query.order_by(
|
||||
Vulnerability.severity.desc(),
|
||||
Vulnerability.line_number
|
||||
).offset(skip).limit(limit).all()
|
||||
|
||||
return vulnerabilities
|
||||
|
||||
@router.get("/{vulnerability_id}", response_model=VulnerabilityResponse)
|
||||
async def get_vulnerability(
|
||||
vulnerability_id: int,
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""获取漏洞详情"""
|
||||
vulnerability = db.query(Vulnerability).filter(Vulnerability.id == vulnerability_id).first()
|
||||
if not vulnerability:
|
||||
raise HTTPException(status_code=404, detail="漏洞不存在")
|
||||
return vulnerability
|
||||
|
||||
@router.put("/{vulnerability_id}", response_model=VulnerabilityResponse)
|
||||
async def update_vulnerability(
|
||||
vulnerability_id: int,
|
||||
vulnerability_update: VulnerabilityUpdate,
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""更新漏洞状态"""
|
||||
vulnerability = db.query(Vulnerability).filter(Vulnerability.id == vulnerability_id).first()
|
||||
if not vulnerability:
|
||||
raise HTTPException(status_code=404, detail="漏洞不存在")
|
||||
|
||||
update_data = vulnerability_update.dict(exclude_unset=True)
|
||||
for field, value in update_data.items():
|
||||
setattr(vulnerability, field, value)
|
||||
|
||||
db.commit()
|
||||
db.refresh(vulnerability)
|
||||
return vulnerability
|
||||
|
||||
@router.get("/stats/summary")
|
||||
async def get_vulnerability_stats(
|
||||
scan_id: Optional[int] = None,
|
||||
project_id: Optional[int] = None,
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""获取漏洞统计信息"""
|
||||
query = db.query(Vulnerability)
|
||||
|
||||
if scan_id:
|
||||
query = query.filter(Vulnerability.scan_id == scan_id)
|
||||
elif project_id:
|
||||
# 通过项目ID获取最新的扫描ID
|
||||
from app.models.scan import Scan
|
||||
latest_scan = db.query(Scan).filter(
|
||||
Scan.project_id == project_id,
|
||||
Scan.status == "completed"
|
||||
).order_by(Scan.completed_at.desc()).first()
|
||||
|
||||
if latest_scan:
|
||||
query = query.filter(Vulnerability.scan_id == latest_scan.id)
|
||||
else:
|
||||
return {"total": 0, "by_severity": {}, "by_category": {}}
|
||||
|
||||
vulnerabilities = query.all()
|
||||
|
||||
# 统计信息
|
||||
total = len(vulnerabilities)
|
||||
by_severity = {}
|
||||
by_category = {}
|
||||
|
||||
for vuln in vulnerabilities:
|
||||
# 按严重程度统计
|
||||
severity = vuln.severity.value
|
||||
by_severity[severity] = by_severity.get(severity, 0) + 1
|
||||
|
||||
# 按分类统计
|
||||
category = vuln.category.value
|
||||
by_category[category] = by_category.get(category, 0) + 1
|
||||
|
||||
return {
|
||||
"total": total,
|
||||
"by_severity": by_severity,
|
||||
"by_category": by_category
|
||||
}
|
||||
@ -0,0 +1 @@
|
||||
# 核心模块包
|
||||
@ -0,0 +1 @@
|
||||
# 分析器包
|
||||
@ -0,0 +1,35 @@
|
||||
"""
|
||||
数据库配置和连接管理
|
||||
"""
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
import os
|
||||
|
||||
# 数据库URL配置
|
||||
DATABASE_URL = os.getenv("DATABASE_URL", "sqlite:///./code_scanner.db")
|
||||
|
||||
# 创建数据库引擎
|
||||
engine = create_engine(
|
||||
DATABASE_URL,
|
||||
connect_args={"check_same_thread": False} if "sqlite" in DATABASE_URL else {}
|
||||
)
|
||||
|
||||
# 创建会话工厂
|
||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
|
||||
# 创建基础模型类
|
||||
Base = declarative_base()
|
||||
|
||||
def get_db():
|
||||
"""获取数据库会话"""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
def init_db():
|
||||
"""初始化数据库表"""
|
||||
from app.models import project, scan, vulnerability
|
||||
Base.metadata.create_all(bind=engine)
|
||||
@ -0,0 +1,4 @@
|
||||
# 数据模型包
|
||||
from .project import Project
|
||||
from .scan import Scan
|
||||
from .vulnerability import Vulnerability
|
||||
@ -0,0 +1,32 @@
|
||||
"""
|
||||
项目数据模型
|
||||
"""
|
||||
from sqlalchemy import Column, Integer, String, DateTime, Text, Boolean
|
||||
from sqlalchemy.sql import func
|
||||
from sqlalchemy.orm import relationship
|
||||
from app.database import Base
|
||||
|
||||
class Project(Base):
|
||||
"""项目模型"""
|
||||
__tablename__ = "projects"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
name = Column(String(100), nullable=False, index=True)
|
||||
description = Column(Text)
|
||||
language = Column(String(20), nullable=False) # Python, C++, JavaScript等
|
||||
repository_url = Column(String(500))
|
||||
project_path = Column(String(500)) # 本地项目路径
|
||||
config = Column(Text) # JSON格式的配置信息
|
||||
|
||||
# 状态字段
|
||||
is_active = Column(Boolean, default=True)
|
||||
|
||||
# 时间戳
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
updated_at = Column(DateTime(timezone=True), onupdate=func.now())
|
||||
|
||||
# 关联关系
|
||||
scans = relationship("Scan", back_populates="project", cascade="all, delete-orphan")
|
||||
|
||||
def __repr__(self):
|
||||
return f"<Project(id={self.id}, name='{self.name}', language='{self.language}')>"
|
||||
@ -0,0 +1,57 @@
|
||||
"""
|
||||
扫描任务数据模型
|
||||
"""
|
||||
from sqlalchemy import Column, Integer, String, DateTime, Text, Boolean, ForeignKey, Enum
|
||||
from sqlalchemy.sql import func
|
||||
from sqlalchemy.orm import relationship
|
||||
import enum
|
||||
from app.database import Base
|
||||
|
||||
class ScanStatus(enum.Enum):
|
||||
"""扫描状态枚举"""
|
||||
PENDING = "pending"
|
||||
RUNNING = "running"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
CANCELLED = "cancelled"
|
||||
|
||||
class ScanType(enum.Enum):
|
||||
"""扫描类型枚举"""
|
||||
FULL = "full" # 全量扫描
|
||||
INCREMENTAL = "incremental" # 增量扫描
|
||||
CUSTOM = "custom" # 自定义扫描
|
||||
|
||||
class Scan(Base):
|
||||
"""扫描任务模型"""
|
||||
__tablename__ = "scans"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
project_id = Column(Integer, ForeignKey("projects.id"), nullable=False)
|
||||
|
||||
# 扫描配置
|
||||
scan_type = Column(Enum(ScanType), default=ScanType.FULL)
|
||||
scan_config = Column(Text) # JSON格式的扫描配置
|
||||
|
||||
# 扫描状态
|
||||
status = Column(Enum(ScanStatus), default=ScanStatus.PENDING)
|
||||
|
||||
# 扫描统计
|
||||
total_files = Column(Integer, default=0)
|
||||
scanned_files = Column(Integer, default=0)
|
||||
total_vulnerabilities = Column(Integer, default=0)
|
||||
|
||||
# 时间戳
|
||||
started_at = Column(DateTime(timezone=True))
|
||||
completed_at = Column(DateTime(timezone=True))
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
|
||||
# 结果信息
|
||||
result_summary = Column(Text) # JSON格式的结果摘要
|
||||
error_message = Column(Text) # 错误信息
|
||||
|
||||
# 关联关系
|
||||
project = relationship("Project", back_populates="scans")
|
||||
vulnerabilities = relationship("Vulnerability", back_populates="scan", cascade="all, delete-orphan")
|
||||
|
||||
def __repr__(self):
|
||||
return f"<Scan(id={self.id}, project_id={self.project_id}, status='{self.status.value}')>"
|
||||
@ -0,0 +1,77 @@
|
||||
"""
|
||||
漏洞数据模型
|
||||
"""
|
||||
from sqlalchemy import Column, Integer, String, DateTime, Text, Boolean, ForeignKey, Enum, Float
|
||||
from sqlalchemy.sql import func
|
||||
from sqlalchemy.orm import relationship
|
||||
import enum
|
||||
from app.database import Base
|
||||
|
||||
class SeverityLevel(enum.Enum):
|
||||
"""严重程度枚举"""
|
||||
CRITICAL = "critical"
|
||||
HIGH = "high"
|
||||
MEDIUM = "medium"
|
||||
LOW = "low"
|
||||
INFO = "info"
|
||||
|
||||
class VulnerabilityCategory(enum.Enum):
|
||||
"""漏洞分类枚举"""
|
||||
SECURITY = "security"
|
||||
PERFORMANCE = "performance"
|
||||
MAINTAINABILITY = "maintainability"
|
||||
RELIABILITY = "reliability"
|
||||
USABILITY = "usability"
|
||||
|
||||
class VulnerabilityStatus(enum.Enum):
|
||||
"""漏洞状态枚举"""
|
||||
OPEN = "open"
|
||||
FIXED = "fixed"
|
||||
FALSE_POSITIVE = "false_positive"
|
||||
WONT_FIX = "wont_fix"
|
||||
|
||||
class Vulnerability(Base):
|
||||
"""漏洞模型"""
|
||||
__tablename__ = "vulnerabilities"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
scan_id = Column(Integer, ForeignKey("scans.id"), nullable=False)
|
||||
|
||||
# 漏洞基本信息
|
||||
rule_id = Column(String(100), nullable=False) # 规则ID
|
||||
message = Column(Text, nullable=False) # 漏洞描述
|
||||
category = Column(Enum(VulnerabilityCategory), nullable=False)
|
||||
severity = Column(Enum(SeverityLevel), nullable=False)
|
||||
|
||||
# 位置信息
|
||||
file_path = Column(String(500), nullable=False)
|
||||
line_number = Column(Integer)
|
||||
column_number = Column(Integer)
|
||||
end_line = Column(Integer)
|
||||
end_column = Column(Integer)
|
||||
|
||||
# 代码上下文
|
||||
code_snippet = Column(Text) # 相关代码片段
|
||||
context_before = Column(Text) # 前置代码上下文
|
||||
context_after = Column(Text) # 后置代码上下文
|
||||
|
||||
# AI增强信息
|
||||
ai_enhanced = Column(Boolean, default=False)
|
||||
ai_confidence = Column(Float) # AI置信度 0-1
|
||||
ai_suggestion = Column(Text) # AI修复建议
|
||||
|
||||
# 状态管理
|
||||
status = Column(Enum(VulnerabilityStatus), default=VulnerabilityStatus.OPEN)
|
||||
assigned_to = Column(String(100)) # 分配给谁
|
||||
fix_commit = Column(String(100)) # 修复的提交哈希
|
||||
|
||||
# 时间戳
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
updated_at = Column(DateTime(timezone=True), onupdate=func.now())
|
||||
fixed_at = Column(DateTime(timezone=True))
|
||||
|
||||
# 关联关系
|
||||
scan = relationship("Scan", back_populates="vulnerabilities")
|
||||
|
||||
def __repr__(self):
|
||||
return f"<Vulnerability(id={self.id}, rule_id='{self.rule_id}', severity='{self.severity.value}')>"
|
||||
@ -0,0 +1 @@
|
||||
# 数据模式包
|
||||
@ -0,0 +1,37 @@
|
||||
"""
|
||||
项目相关的Pydantic模式
|
||||
"""
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional
|
||||
from datetime import datetime
|
||||
|
||||
class ProjectBase(BaseModel):
|
||||
"""项目基础模式"""
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
language: str
|
||||
repository_url: Optional[str] = None
|
||||
project_path: Optional[str] = None
|
||||
config: Optional[str] = None
|
||||
|
||||
class ProjectCreate(ProjectBase):
|
||||
"""创建项目模式"""
|
||||
pass
|
||||
|
||||
class ProjectUpdate(BaseModel):
|
||||
"""更新项目模式"""
|
||||
name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
repository_url: Optional[str] = None
|
||||
project_path: Optional[str] = None
|
||||
config: Optional[str] = None
|
||||
|
||||
class ProjectResponse(ProjectBase):
|
||||
"""项目响应模式"""
|
||||
id: int
|
||||
is_active: bool
|
||||
created_at: datetime
|
||||
updated_at: Optional[datetime] = None
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
@ -0,0 +1,42 @@
|
||||
"""
|
||||
扫描相关的Pydantic模式
|
||||
"""
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional
|
||||
from datetime import datetime
|
||||
from app.models.scan import ScanStatus, ScanType
|
||||
|
||||
class ScanBase(BaseModel):
|
||||
"""扫描基础模式"""
|
||||
project_id: int
|
||||
scan_type: ScanType = ScanType.FULL
|
||||
scan_config: Optional[str] = None
|
||||
|
||||
class ScanCreate(ScanBase):
|
||||
"""创建扫描模式"""
|
||||
pass
|
||||
|
||||
class ScanResponse(ScanBase):
|
||||
"""扫描响应模式"""
|
||||
id: int
|
||||
status: ScanStatus
|
||||
total_files: int
|
||||
scanned_files: int
|
||||
total_vulnerabilities: int
|
||||
started_at: Optional[datetime] = None
|
||||
completed_at: Optional[datetime] = None
|
||||
created_at: datetime
|
||||
result_summary: Optional[str] = None
|
||||
error_message: Optional[str] = None
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
class ScanStatusResponse(BaseModel):
|
||||
"""扫描状态响应模式"""
|
||||
scan_id: int
|
||||
status: str
|
||||
progress: dict
|
||||
started_at: Optional[datetime] = None
|
||||
completed_at: Optional[datetime] = None
|
||||
error_message: Optional[str] = None
|
||||
@ -0,0 +1,49 @@
|
||||
"""
|
||||
漏洞相关的Pydantic模式
|
||||
"""
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional
|
||||
from datetime import datetime
|
||||
from app.models.vulnerability import SeverityLevel, VulnerabilityCategory, VulnerabilityStatus
|
||||
|
||||
class VulnerabilityBase(BaseModel):
|
||||
"""漏洞基础模式"""
|
||||
rule_id: str
|
||||
message: str
|
||||
category: VulnerabilityCategory
|
||||
severity: SeverityLevel
|
||||
file_path: str
|
||||
line_number: Optional[int] = None
|
||||
column_number: Optional[int] = None
|
||||
end_line: Optional[int] = None
|
||||
end_column: Optional[int] = None
|
||||
code_snippet: Optional[str] = None
|
||||
context_before: Optional[str] = None
|
||||
context_after: Optional[str] = None
|
||||
ai_enhanced: bool = False
|
||||
ai_confidence: Optional[float] = None
|
||||
ai_suggestion: Optional[str] = None
|
||||
|
||||
class VulnerabilityCreate(VulnerabilityBase):
|
||||
"""创建漏洞模式"""
|
||||
scan_id: int
|
||||
|
||||
class VulnerabilityUpdate(BaseModel):
|
||||
"""更新漏洞模式"""
|
||||
status: Optional[VulnerabilityStatus] = None
|
||||
assigned_to: Optional[str] = None
|
||||
fix_commit: Optional[str] = None
|
||||
|
||||
class VulnerabilityResponse(VulnerabilityBase):
|
||||
"""漏洞响应模式"""
|
||||
id: int
|
||||
scan_id: int
|
||||
status: VulnerabilityStatus
|
||||
assigned_to: Optional[str] = None
|
||||
fix_commit: Optional[str] = None
|
||||
created_at: datetime
|
||||
updated_at: Optional[datetime] = None
|
||||
fixed_at: Optional[datetime] = None
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
@ -0,0 +1 @@
|
||||
# 服务层包
|
||||
@ -0,0 +1,60 @@
|
||||
"""
|
||||
代码分析服务
|
||||
"""
|
||||
import os
|
||||
import json
|
||||
import subprocess
|
||||
from typing import List, Dict, Any
|
||||
from app.core.analyzers.base_analyzer import BaseAnalyzer
|
||||
from app.core.analyzers.python_analyzer import PythonAnalyzer
|
||||
from app.core.analyzers.cpp_analyzer import CppAnalyzer
|
||||
from app.core.analyzers.javascript_analyzer import JavaScriptAnalyzer
|
||||
from app.services.ai_service import AIService
|
||||
|
||||
class AnalyzerService:
|
||||
"""代码分析服务"""
|
||||
|
||||
def __init__(self):
|
||||
self.analyzers = {
|
||||
'python': PythonAnalyzer(),
|
||||
'cpp': CppAnalyzer(),
|
||||
'javascript': JavaScriptAnalyzer(),
|
||||
}
|
||||
self.ai_service = AIService()
|
||||
|
||||
async def analyze_project(self, project_path: str, language: str, config: Dict[str, Any] = None) -> List[Dict[str, Any]]:
|
||||
"""分析项目代码"""
|
||||
if language not in self.analyzers:
|
||||
raise ValueError(f"不支持的语言: {language}")
|
||||
|
||||
analyzer = self.analyzers[language]
|
||||
|
||||
# 运行静态分析
|
||||
static_results = await analyzer.analyze(project_path, config)
|
||||
|
||||
# AI增强分析
|
||||
ai_enhanced_results = []
|
||||
for result in static_results:
|
||||
# 对每个漏洞进行AI增强
|
||||
ai_enhancement = await self.ai_service.enhance_vulnerability(result)
|
||||
result.update(ai_enhancement)
|
||||
ai_enhanced_results.append(result)
|
||||
|
||||
return ai_enhanced_results
|
||||
|
||||
def get_supported_languages(self) -> List[str]:
|
||||
"""获取支持的语言列表"""
|
||||
return list(self.analyzers.keys())
|
||||
|
||||
def get_analyzer_info(self, language: str) -> Dict[str, Any]:
|
||||
"""获取分析器信息"""
|
||||
if language not in self.analyzers:
|
||||
return None
|
||||
|
||||
analyzer = self.analyzers[language]
|
||||
return {
|
||||
'name': analyzer.name,
|
||||
'version': analyzer.version,
|
||||
'supported_extensions': analyzer.supported_extensions,
|
||||
'description': analyzer.description
|
||||
}
|
||||
@ -0,0 +1,120 @@
|
||||
"""
|
||||
仪表板服务
|
||||
"""
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import func, desc
|
||||
from app.models.project import Project
|
||||
from app.models.scan import Scan
|
||||
from app.models.vulnerability import Vulnerability, VulnerabilityStatus, SeverityLevel
|
||||
|
||||
class DashboardService:
|
||||
"""仪表板服务"""
|
||||
|
||||
async def get_summary_data(self, db: Session) -> dict:
|
||||
"""获取仪表板汇总数据"""
|
||||
# 项目统计
|
||||
total_projects = db.query(Project).filter(Project.is_active == True).count()
|
||||
|
||||
# 扫描统计
|
||||
total_scans = db.query(Scan).count()
|
||||
completed_scans = db.query(Scan).filter(Scan.status == "completed").count()
|
||||
|
||||
# 漏洞统计
|
||||
total_vulnerabilities = db.query(Vulnerability).count()
|
||||
fixed_vulnerabilities = db.query(Vulnerability).filter(
|
||||
Vulnerability.status == VulnerabilityStatus.FIXED
|
||||
).count()
|
||||
|
||||
# 按严重程度统计
|
||||
severity_stats = db.query(
|
||||
Vulnerability.severity,
|
||||
func.count(Vulnerability.id).label('count')
|
||||
).group_by(Vulnerability.severity).all()
|
||||
|
||||
severity_summary = {}
|
||||
for severity, count in severity_stats:
|
||||
severity_summary[severity.value] = count
|
||||
|
||||
# 最近发现的漏洞
|
||||
recent_vulnerabilities = db.query(Vulnerability).order_by(
|
||||
desc(Vulnerability.created_at)
|
||||
).limit(10).all()
|
||||
|
||||
recent_vuln_data = []
|
||||
for vuln in recent_vulnerabilities:
|
||||
recent_vuln_data.append({
|
||||
'id': vuln.id,
|
||||
'project_name': vuln.scan.project.name if vuln.scan and vuln.scan.project else 'Unknown',
|
||||
'category': vuln.category.value,
|
||||
'severity': vuln.severity.value,
|
||||
'file_path': vuln.file_path,
|
||||
'message': vuln.message,
|
||||
'created_at': vuln.created_at.isoformat()
|
||||
})
|
||||
|
||||
# 项目漏洞统计
|
||||
project_stats = db.query(
|
||||
Project.name,
|
||||
func.count(Vulnerability.id).label('vulnerability_count')
|
||||
).join(Scan, Project.id == Scan.project_id).join(
|
||||
Vulnerability, Scan.id == Vulnerability.scan_id
|
||||
).group_by(Project.id, Project.name).order_by(
|
||||
desc('vulnerability_count')
|
||||
).limit(5).all()
|
||||
|
||||
project_vuln_data = []
|
||||
for project_name, count in project_stats:
|
||||
project_vuln_data.append({
|
||||
'project_name': project_name,
|
||||
'vulnerability_count': count
|
||||
})
|
||||
|
||||
return {
|
||||
'projects': total_projects,
|
||||
'scans': total_scans,
|
||||
'completed_scans': completed_scans,
|
||||
'vulnerabilities': total_vulnerabilities,
|
||||
'fixed': fixed_vulnerabilities,
|
||||
'severity_summary': severity_summary,
|
||||
'recent_vulnerabilities': recent_vuln_data,
|
||||
'project_stats': project_vuln_data
|
||||
}
|
||||
|
||||
async def get_trend_data(self, db: Session, days: int = 30) -> dict:
|
||||
"""获取趋势数据"""
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
end_date = datetime.now()
|
||||
start_date = end_date - timedelta(days=days)
|
||||
|
||||
# 每日扫描统计
|
||||
daily_scans = db.query(
|
||||
func.date(Scan.created_at).label('date'),
|
||||
func.count(Scan.id).label('count')
|
||||
).filter(
|
||||
Scan.created_at >= start_date
|
||||
).group_by(func.date(Scan.created_at)).all()
|
||||
|
||||
# 每日漏洞统计
|
||||
daily_vulnerabilities = db.query(
|
||||
func.date(Vulnerability.created_at).label('date'),
|
||||
func.count(Vulnerability.id).label('count')
|
||||
).filter(
|
||||
Vulnerability.created_at >= start_date
|
||||
).group_by(func.date(Vulnerability.created_at)).all()
|
||||
|
||||
return {
|
||||
'daily_scans': [{'date': str(date), 'count': count} for date, count in daily_scans],
|
||||
'daily_vulnerabilities': [{'date': str(date), 'count': count} for date, count in daily_vulnerabilities]
|
||||
}
|
||||
|
||||
async def get_category_distribution(self, db: Session) -> dict:
|
||||
"""获取漏洞分类分布"""
|
||||
category_stats = db.query(
|
||||
Vulnerability.category,
|
||||
func.count(Vulnerability.id).label('count')
|
||||
).group_by(Vulnerability.category).all()
|
||||
|
||||
return {
|
||||
category.value: count for category, count in category_stats
|
||||
}
|
||||
@ -0,0 +1,162 @@
|
||||
"""
|
||||
扫描服务
|
||||
"""
|
||||
import asyncio
|
||||
from typing import List
|
||||
from sqlalchemy.orm import Session
|
||||
from app.database import SessionLocal
|
||||
from app.models.scan import Scan, ScanStatus
|
||||
from app.models.vulnerability import Vulnerability, VulnerabilityStatus, SeverityLevel, VulnerabilityCategory
|
||||
from app.services.analyzer_service import AnalyzerService
|
||||
from app.services.report_service import ReportService
|
||||
import json
|
||||
|
||||
class ScanService:
|
||||
"""扫描服务"""
|
||||
|
||||
def __init__(self):
|
||||
self.analyzer_service = AnalyzerService()
|
||||
self.report_service = ReportService()
|
||||
|
||||
async def run_scan(self, scan_id: int):
|
||||
"""运行扫描任务"""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
# 获取扫描任务
|
||||
scan = db.query(Scan).filter(Scan.id == scan_id).first()
|
||||
if not scan:
|
||||
print(f"扫描任务不存在: {scan_id}")
|
||||
return
|
||||
|
||||
# 更新扫描状态为运行中
|
||||
scan.status = ScanStatus.RUNNING
|
||||
scan.started_at = asyncio.get_event_loop().time()
|
||||
db.commit()
|
||||
|
||||
try:
|
||||
# 获取项目信息
|
||||
project = scan.project
|
||||
if not project:
|
||||
raise Exception("项目不存在")
|
||||
|
||||
# 解析扫描配置
|
||||
scan_config = {}
|
||||
if scan.scan_config:
|
||||
scan_config = json.loads(scan.scan_config)
|
||||
|
||||
# 执行代码分析
|
||||
vulnerabilities_data = await self.analyzer_service.analyze_project(
|
||||
project_path=project.project_path,
|
||||
language=project.language,
|
||||
config=scan_config
|
||||
)
|
||||
|
||||
# 保存漏洞数据到数据库
|
||||
await self._save_vulnerabilities(db, scan, vulnerabilities_data)
|
||||
|
||||
# 更新扫描统计
|
||||
scan.total_files = scan_config.get('total_files', 100) # 模拟文件数
|
||||
scan.scanned_files = scan.total_files
|
||||
scan.total_vulnerabilities = len(vulnerabilities_data)
|
||||
scan.status = ScanStatus.COMPLETED
|
||||
scan.completed_at = asyncio.get_event_loop().time()
|
||||
|
||||
# 生成结果摘要
|
||||
summary = {
|
||||
'total_vulnerabilities': scan.total_vulnerabilities,
|
||||
'by_severity': {},
|
||||
'by_category': {}
|
||||
}
|
||||
|
||||
for vuln_data in vulnerabilities_data:
|
||||
severity = vuln_data.get('severity', 'medium')
|
||||
category = vuln_data.get('category', 'maintainability')
|
||||
|
||||
summary['by_severity'][severity] = summary['by_severity'].get(severity, 0) + 1
|
||||
summary['by_category'][category] = summary['by_category'].get(category, 0) + 1
|
||||
|
||||
scan.result_summary = json.dumps(summary)
|
||||
db.commit()
|
||||
|
||||
print(f"扫描完成: {scan_id}, 发现 {scan.total_vulnerabilities} 个漏洞")
|
||||
|
||||
except Exception as e:
|
||||
# 扫描失败
|
||||
scan.status = ScanStatus.FAILED
|
||||
scan.error_message = str(e)
|
||||
db.commit()
|
||||
print(f"扫描失败: {scan_id}, 错误: {str(e)}")
|
||||
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
async def _save_vulnerabilities(self, db: Session, scan: Scan, vulnerabilities_data: List[dict]):
|
||||
"""保存漏洞数据到数据库"""
|
||||
for vuln_data in vulnerabilities_data:
|
||||
# 映射严重程度
|
||||
severity_mapping = {
|
||||
'critical': SeverityLevel.CRITICAL,
|
||||
'high': SeverityLevel.HIGH,
|
||||
'medium': SeverityLevel.MEDIUM,
|
||||
'low': SeverityLevel.LOW,
|
||||
'info': SeverityLevel.INFO
|
||||
}
|
||||
|
||||
# 映射分类
|
||||
category_mapping = {
|
||||
'security': VulnerabilityCategory.SECURITY,
|
||||
'performance': VulnerabilityCategory.PERFORMANCE,
|
||||
'maintainability': VulnerabilityCategory.MAINTAINABILITY,
|
||||
'reliability': VulnerabilityCategory.RELIABILITY,
|
||||
'usability': VulnerabilityCategory.USABILITY
|
||||
}
|
||||
|
||||
vulnerability = Vulnerability(
|
||||
scan_id=scan.id,
|
||||
rule_id=vuln_data.get('rule_id', 'unknown'),
|
||||
message=vuln_data.get('message', ''),
|
||||
category=category_mapping.get(vuln_data.get('category', 'maintainability'), VulnerabilityCategory.MAINTAINABILITY),
|
||||
severity=severity_mapping.get(vuln_data.get('severity', 'medium'), SeverityLevel.MEDIUM),
|
||||
file_path=vuln_data.get('file_path', ''),
|
||||
line_number=vuln_data.get('line_number'),
|
||||
column_number=vuln_data.get('column_number'),
|
||||
end_line=vuln_data.get('end_line'),
|
||||
end_column=vuln_data.get('end_column'),
|
||||
code_snippet=vuln_data.get('code_snippet', ''),
|
||||
context_before=vuln_data.get('context_before', ''),
|
||||
context_after=vuln_data.get('context_after', ''),
|
||||
ai_enhanced=vuln_data.get('ai_enhanced', False),
|
||||
ai_confidence=vuln_data.get('ai_confidence'),
|
||||
ai_suggestion=vuln_data.get('ai_suggestion', ''),
|
||||
status=VulnerabilityStatus.OPEN
|
||||
)
|
||||
|
||||
db.add(vulnerability)
|
||||
|
||||
db.commit()
|
||||
|
||||
def get_scan_progress(self, scan_id: int) -> dict:
|
||||
"""获取扫描进度"""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
scan = db.query(Scan).filter(Scan.id == scan_id).first()
|
||||
if not scan:
|
||||
return {'error': '扫描任务不存在'}
|
||||
|
||||
progress = 0
|
||||
if scan.total_files > 0:
|
||||
progress = (scan.scanned_files / scan.total_files) * 100
|
||||
|
||||
return {
|
||||
'scan_id': scan_id,
|
||||
'status': scan.status.value,
|
||||
'progress': progress,
|
||||
'total_files': scan.total_files,
|
||||
'scanned_files': scan.scanned_files,
|
||||
'total_vulnerabilities': scan.total_vulnerabilities,
|
||||
'started_at': scan.started_at,
|
||||
'completed_at': scan.completed_at,
|
||||
'error_message': scan.error_message
|
||||
}
|
||||
finally:
|
||||
db.close()
|
||||
@ -0,0 +1,247 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="zh-CN">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>代码扫描报告 - {{ project.name }}</title>
|
||||
<style>
|
||||
body {
|
||||
font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
|
||||
line-height: 1.6;
|
||||
margin: 0;
|
||||
padding: 20px;
|
||||
background-color: #f5f5f5;
|
||||
}
|
||||
.container {
|
||||
max-width: 1200px;
|
||||
margin: 0 auto;
|
||||
background: white;
|
||||
padding: 30px;
|
||||
border-radius: 8px;
|
||||
box-shadow: 0 2px 10px rgba(0,0,0,0.1);
|
||||
}
|
||||
.header {
|
||||
text-align: center;
|
||||
border-bottom: 2px solid #1890ff;
|
||||
padding-bottom: 20px;
|
||||
margin-bottom: 30px;
|
||||
}
|
||||
.header h1 {
|
||||
color: #1890ff;
|
||||
margin: 0;
|
||||
font-size: 2.5em;
|
||||
}
|
||||
.header p {
|
||||
color: #666;
|
||||
margin: 10px 0 0 0;
|
||||
}
|
||||
.summary {
|
||||
display: grid;
|
||||
grid-template-columns: repeat(auto-fit, minmax(200px, 1fr));
|
||||
gap: 20px;
|
||||
margin-bottom: 30px;
|
||||
}
|
||||
.summary-card {
|
||||
background: #f8f9fa;
|
||||
padding: 20px;
|
||||
border-radius: 8px;
|
||||
text-align: center;
|
||||
border-left: 4px solid #1890ff;
|
||||
}
|
||||
.summary-card h3 {
|
||||
margin: 0 0 10px 0;
|
||||
color: #333;
|
||||
}
|
||||
.summary-card .number {
|
||||
font-size: 2em;
|
||||
font-weight: bold;
|
||||
color: #1890ff;
|
||||
}
|
||||
.severity-critical { border-left-color: #ff4d4f; }
|
||||
.severity-high { border-left-color: #ff7a45; }
|
||||
.severity-medium { border-left-color: #ffa940; }
|
||||
.severity-low { border-left-color: #73d13d; }
|
||||
.severity-info { border-left-color: #40a9ff; }
|
||||
|
||||
.section {
|
||||
margin-bottom: 30px;
|
||||
}
|
||||
.section h2 {
|
||||
color: #333;
|
||||
border-bottom: 2px solid #f0f0f0;
|
||||
padding-bottom: 10px;
|
||||
}
|
||||
.vulnerability {
|
||||
background: #fff;
|
||||
border: 1px solid #e8e8e8;
|
||||
border-radius: 8px;
|
||||
margin-bottom: 15px;
|
||||
overflow: hidden;
|
||||
}
|
||||
.vulnerability-header {
|
||||
background: #f8f9fa;
|
||||
padding: 15px 20px;
|
||||
border-bottom: 1px solid #e8e8e8;
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
align-items: center;
|
||||
}
|
||||
.vulnerability-title {
|
||||
font-weight: bold;
|
||||
font-size: 1.1em;
|
||||
}
|
||||
.severity-badge {
|
||||
padding: 4px 12px;
|
||||
border-radius: 20px;
|
||||
color: white;
|
||||
font-size: 0.9em;
|
||||
font-weight: bold;
|
||||
}
|
||||
.severity-critical { background: #ff4d4f; }
|
||||
.severity-high { background: #ff7a45; }
|
||||
.severity-medium { background: #ffa940; }
|
||||
.severity-low { background: #73d13d; }
|
||||
.severity-info { background: #40a9ff; }
|
||||
|
||||
.vulnerability-body {
|
||||
padding: 20px;
|
||||
}
|
||||
.vulnerability-meta {
|
||||
display: grid;
|
||||
grid-template-columns: 1fr 1fr;
|
||||
gap: 20px;
|
||||
margin-bottom: 15px;
|
||||
}
|
||||
.meta-item {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
}
|
||||
.meta-label {
|
||||
font-weight: bold;
|
||||
margin-right: 10px;
|
||||
min-width: 80px;
|
||||
}
|
||||
.file-path {
|
||||
font-family: 'Courier New', monospace;
|
||||
background: #f5f5f5;
|
||||
padding: 2px 6px;
|
||||
border-radius: 4px;
|
||||
}
|
||||
.code-block {
|
||||
background: #f8f8f8;
|
||||
border: 1px solid #e8e8e8;
|
||||
border-radius: 4px;
|
||||
padding: 15px;
|
||||
margin: 10px 0;
|
||||
font-family: 'Courier New', monospace;
|
||||
font-size: 0.9em;
|
||||
overflow-x: auto;
|
||||
}
|
||||
.ai-suggestion {
|
||||
background: #e6f7ff;
|
||||
border: 1px solid #91d5ff;
|
||||
border-radius: 4px;
|
||||
padding: 15px;
|
||||
margin-top: 10px;
|
||||
}
|
||||
.ai-suggestion h4 {
|
||||
margin: 0 0 10px 0;
|
||||
color: #1890ff;
|
||||
}
|
||||
.footer {
|
||||
text-align: center;
|
||||
margin-top: 40px;
|
||||
padding-top: 20px;
|
||||
border-top: 1px solid #e8e8e8;
|
||||
color: #666;
|
||||
}
|
||||
@media print {
|
||||
body { background: white; }
|
||||
.container { box-shadow: none; }
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<div class="header">
|
||||
<h1>代码扫描报告</h1>
|
||||
<p>项目: {{ project.name }} | 生成时间: {{ generated_at }}</p>
|
||||
</div>
|
||||
|
||||
<!-- 扫描摘要 -->
|
||||
<div class="section">
|
||||
<h2>扫描摘要</h2>
|
||||
<div class="summary">
|
||||
<div class="summary-card">
|
||||
<h3>总漏洞数</h3>
|
||||
<div class="number">{{ total_vulnerabilities }}</div>
|
||||
</div>
|
||||
{% for severity, vulns in by_severity.items() %}
|
||||
<div class="summary-card severity-{{ severity }}">
|
||||
<h3>{{ severity|title }} 漏洞</h3>
|
||||
<div class="number">{{ vulns|length }}</div>
|
||||
</div>
|
||||
{% endfor %}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- 漏洞详情 -->
|
||||
<div class="section">
|
||||
<h2>漏洞详情</h2>
|
||||
{% for vulnerability in vulnerabilities %}
|
||||
<div class="vulnerability">
|
||||
<div class="vulnerability-header">
|
||||
<div class="vulnerability-title">
|
||||
{{ vulnerability.rule_id }}: {{ vulnerability.message }}
|
||||
</div>
|
||||
<span class="severity-badge severity-{{ vulnerability.severity.value }}">
|
||||
{{ vulnerability.severity.value|upper }}
|
||||
</span>
|
||||
</div>
|
||||
<div class="vulnerability-body">
|
||||
<div class="vulnerability-meta">
|
||||
<div class="meta-item">
|
||||
<span class="meta-label">文件:</span>
|
||||
<span class="file-path">{{ vulnerability.file_path }}</span>
|
||||
</div>
|
||||
<div class="meta-item">
|
||||
<span class="meta-label">行号:</span>
|
||||
<span>{{ vulnerability.line_number or 'N/A' }}</span>
|
||||
</div>
|
||||
<div class="meta-item">
|
||||
<span class="meta-label">分类:</span>
|
||||
<span>{{ vulnerability.category.value }}</span>
|
||||
</div>
|
||||
<div class="meta-item">
|
||||
<span class="meta-label">状态:</span>
|
||||
<span>{{ vulnerability.status.value }}</span>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{% if vulnerability.code_snippet %}
|
||||
<div>
|
||||
<strong>相关代码:</strong>
|
||||
<div class="code-block">{{ vulnerability.code_snippet }}</div>
|
||||
</div>
|
||||
{% endif %}
|
||||
|
||||
{% if vulnerability.ai_enhanced and vulnerability.ai_suggestion %}
|
||||
<div class="ai-suggestion">
|
||||
<h4>🤖 AI 建议</h4>
|
||||
<p>{{ vulnerability.ai_suggestion }}</p>
|
||||
{% if vulnerability.ai_confidence %}
|
||||
<small>置信度: {{ (vulnerability.ai_confidence * 100)|round(1) }}%</small>
|
||||
{% endif %}
|
||||
</div>
|
||||
{% endif %}
|
||||
</div>
|
||||
</div>
|
||||
{% endfor %}
|
||||
</div>
|
||||
|
||||
<div class="footer">
|
||||
<p>此报告由代码漏洞检测系统自动生成</p>
|
||||
</div>
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
@ -0,0 +1 @@
|
||||
# 工具类包
|
||||
@ -0,0 +1,33 @@
|
||||
"""
|
||||
配置管理
|
||||
"""
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
class Config:
|
||||
"""配置类"""
|
||||
|
||||
# 数据库配置
|
||||
DATABASE_URL: str = os.getenv("DATABASE_URL", "sqlite:///./code_scanner.db")
|
||||
|
||||
# AI服务配置
|
||||
DEEPSEEK_API_URL: str = os.getenv("DEEPSEEK_API_URL", "https://api.deepseek.com/v1/chat/completions")
|
||||
DEEPSEEK_API_KEY: str = os.getenv("DEEPSEEK_API_KEY", "your_deepseek_api_key_here")
|
||||
|
||||
# 文件上传配置
|
||||
UPLOAD_FOLDER: str = os.getenv("UPLOAD_FOLDER", "uploads")
|
||||
MAX_CONTENT_LENGTH: int = int(os.getenv("MAX_CONTENT_LENGTH", "16 * 1024 * 1024")) # 16MB
|
||||
|
||||
# 扫描配置
|
||||
MAX_SCAN_FILES: int = int(os.getenv("MAX_SCAN_FILES", "1000"))
|
||||
SCAN_TIMEOUT: int = int(os.getenv("SCAN_TIMEOUT", "300")) # 5分钟
|
||||
|
||||
# 报告配置
|
||||
REPORTS_FOLDER: str = os.getenv("REPORTS_FOLDER", "reports")
|
||||
|
||||
# 日志配置
|
||||
LOG_LEVEL: str = os.getenv("LOG_LEVEL", "INFO")
|
||||
LOG_FILE: str = os.getenv("LOG_FILE", "app.log")
|
||||
|
||||
# 创建配置实例
|
||||
config = Config()
|
||||
@ -0,0 +1,55 @@
|
||||
"""
|
||||
代码漏洞检测系统 - 后端主启动文件
|
||||
"""
|
||||
import uvicorn
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from app.api import projects, scans, reports, vulnerabilities, files
|
||||
from app.database import init_db
|
||||
|
||||
# 创建FastAPI应用
|
||||
app = FastAPI(
|
||||
title="代码漏洞检测系统",
|
||||
description="基于AI增强的代码漏洞检测和报告生成系统",
|
||||
version="1.0.0"
|
||||
)
|
||||
|
||||
# 配置CORS
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["http://localhost:3000"], # 前端地址
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# 注册路由
|
||||
app.include_router(projects.router, prefix="/api/projects", tags=["projects"])
|
||||
app.include_router(scans.router, prefix="/api/scans", tags=["scans"])
|
||||
app.include_router(reports.router, prefix="/api/reports", tags=["reports"])
|
||||
app.include_router(vulnerabilities.router, prefix="/api/vulnerabilities", tags=["vulnerabilities"])
|
||||
app.include_router(files.router, prefix="/api", tags=["files"])
|
||||
|
||||
@app.on_event("startup")
|
||||
async def startup_event():
|
||||
"""应用启动时初始化数据库"""
|
||||
init_db()
|
||||
|
||||
@app.get("/")
|
||||
async def root():
|
||||
"""根路径健康检查"""
|
||||
return {"message": "代码漏洞检测系统 API", "status": "running"}
|
||||
|
||||
@app.get("/health")
|
||||
async def health_check():
|
||||
"""健康检查接口"""
|
||||
return {"status": "healthy"}
|
||||
|
||||
if __name__ == "__main__":
|
||||
uvicorn.run(
|
||||
"main:app",
|
||||
host="0.0.0.0",
|
||||
port=8000,
|
||||
reload=True,
|
||||
log_level="info"
|
||||
)
|
||||
@ -0,0 +1,12 @@
|
||||
fastapi==0.104.1
|
||||
uvicorn[standard]==0.24.0
|
||||
sqlalchemy==2.0.23
|
||||
pydantic==2.5.0
|
||||
python-multipart==0.0.6
|
||||
requests==2.31.0
|
||||
python-dotenv==1.0.0
|
||||
alembic==1.12.1
|
||||
pandas==2.1.4
|
||||
jinja2==3.1.2
|
||||
weasyprint==60.2
|
||||
openpyxl==3.1.2
|
||||
Loading…
Reference in new issue