You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

119 lines
3.8 KiB

"""
漏洞管理API路由
"""
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy.orm import Session
from typing import List, Optional
from app.database import get_db
from app.models.vulnerability import Vulnerability, VulnerabilityStatus, SeverityLevel, VulnerabilityCategory
from app.schemas.vulnerability import VulnerabilityResponse, VulnerabilityUpdate
# Router holding the vulnerability endpoints defined below; presumably
# included by the application elsewhere (prefix not visible here — TODO confirm).
router = APIRouter()
@router.get("/", response_model=List[VulnerabilityResponse])
async def get_vulnerabilities(
    scan_id: Optional[int] = None,
    project_id: Optional[int] = None,
    severity: Optional[SeverityLevel] = None,
    category: Optional[VulnerabilityCategory] = None,
    status: Optional[VulnerabilityStatus] = None,
    skip: int = Query(0, ge=0),
    limit: int = Query(100, ge=0),
    db: Session = Depends(get_db)
):
    """List vulnerabilities with optional filters and offset pagination.

    Scope selection (mirrors the project scoping of the ``/stats/summary``
    endpoint):
      * ``scan_id`` given: only that scan's vulnerabilities; it takes
        precedence over ``project_id``.
      * otherwise ``project_id`` given: only the vulnerabilities of the
        project's most recently completed scan. An empty list is returned
        when the project has no completed scan.
      * neither: vulnerabilities from every scan.

    ``severity``, ``category`` and ``status`` narrow the result further.
    Results are ordered by severity (descending), then by line number.
    ``skip`` and ``limit`` must be non-negative; FastAPI answers 422
    otherwise. Without this check a negative LIMIT could reach the database,
    and e.g. SQLite treats a negative LIMIT as "no limit".
    """
    # NOTE(review): this is an ``async def`` doing synchronous Session calls,
    # which block the event loop while the query runs; a plain ``def`` would
    # let FastAPI run it in its threadpool. Kept async to preserve the interface.
    query = db.query(Vulnerability)

    # Scope: an explicit scan wins; otherwise resolve the project's latest scan.
    if scan_id is not None:
        query = query.filter(Vulnerability.scan_id == scan_id)
    elif project_id is not None:
        # Local import, matching the style used by the stats endpoint.
        from app.models.scan import Scan
        latest_scan = db.query(Scan).filter(
            Scan.project_id == project_id,
            Scan.status == "completed"
        ).order_by(Scan.completed_at.desc()).first()
        if latest_scan is None:
            return []
        query = query.filter(Vulnerability.scan_id == latest_scan.id)

    # Attribute filters; compare against None so falsy-but-valid values still apply.
    if severity is not None:
        query = query.filter(Vulnerability.severity == severity)
    if category is not None:
        query = query.filter(Vulnerability.category == category)
    if status is not None:
        query = query.filter(Vulnerability.status == status)

    # NOTE(review): descending order follows how ``severity`` is stored in the
    # DB; if it is stored as text, this is alphabetical rather than by
    # criticality — confirm.
    return query.order_by(
        Vulnerability.severity.desc(),
        Vulnerability.line_number
    ).offset(skip).limit(limit).all()
@router.get("/{vulnerability_id}", response_model=VulnerabilityResponse)
async def get_vulnerability(
    vulnerability_id: int,
    db: Session = Depends(get_db)
):
    """Fetch one vulnerability by its id.

    Raises:
        HTTPException: 404 when no vulnerability has ``vulnerability_id``.
    """
    # NOTE(review): synchronous Session calls inside ``async def`` block the
    # event loop; consider a plain ``def`` endpoint.
    by_id = db.query(Vulnerability).filter(Vulnerability.id == vulnerability_id)
    found = by_id.first()
    if found:
        return found
    raise HTTPException(status_code=404, detail="漏洞不存在")
@router.put("/{vulnerability_id}", response_model=VulnerabilityResponse)
async def update_vulnerability(
    vulnerability_id: int,
    vulnerability_update: VulnerabilityUpdate,
    db: Session = Depends(get_db)
):
    """Partially update a vulnerability and return the refreshed row.

    Only the fields present in the request body are written
    (``exclude_unset=True``); omitted fields keep their stored values.

    Raises:
        HTTPException: 404 when no vulnerability has ``vulnerability_id``.
    """
    target = (
        db.query(Vulnerability)
        .filter(Vulnerability.id == vulnerability_id)
        .first()
    )
    if not target:
        raise HTTPException(status_code=404, detail="漏洞不存在")

    # NOTE(review): ``.dict()`` is the Pydantic v1 API (a deprecated alias of
    # ``model_dump`` in v2) — confirm which Pydantic version is pinned.
    changes = vulnerability_update.dict(exclude_unset=True)
    for attr_name in changes:
        setattr(target, attr_name, changes[attr_name])

    # Persist, then reload so server-side values are reflected in the response.
    db.commit()
    db.refresh(target)
    return target
@router.get("/stats/summary")
async def get_vulnerability_stats(
    scan_id: Optional[int] = None,
    project_id: Optional[int] = None,
    db: Session = Depends(get_db)
):
    """Return vulnerability counts: a total, plus per-severity and per-category counts.

    Scope selection:
      * ``scan_id`` given: count that scan's vulnerabilities; it takes
        precedence over ``project_id``.
      * otherwise ``project_id`` given: count the vulnerabilities of the
        project's most recently completed scan. An all-zero summary is
        returned when the project has no completed scan.
      * neither: count every vulnerability.

    Counting is pushed to the database (``GROUP BY`` + ``COUNT``), so rows
    are never materialized in Python.

    Returns:
        ``{"total": int, "by_severity": {...}, "by_category": {...}}``. The
        inner dicts map each enum member's ``.value`` to its count.
    """
    from sqlalchemy import func

    # Build the WHERE criteria once; both GROUP BY queries share them.
    criteria = []
    if scan_id is not None:
        criteria.append(Vulnerability.scan_id == scan_id)
    elif project_id is not None:
        # Resolve the project's most recently completed scan.
        from app.models.scan import Scan
        latest_scan = db.query(Scan).filter(
            Scan.project_id == project_id,
            Scan.status == "completed"
        ).order_by(Scan.completed_at.desc()).first()
        if latest_scan is None:
            return {"total": 0, "by_severity": {}, "by_category": {}}
        criteria.append(Vulnerability.scan_id == latest_scan.id)

    def count_by(column):
        """Map each distinct value of *column* (via ``.value``) to its row count."""
        rows = (
            db.query(column, func.count(Vulnerability.id))
            .filter(*criteria)
            .group_by(column)
            .all()
        )
        return {key.value: count for key, count in rows}

    by_severity = count_by(Vulnerability.severity)
    by_category = count_by(Vulnerability.category)

    return {
        # Every row lands in exactly one severity group, so the sum is the row count.
        "total": sum(by_severity.values()),
        "by_severity": by_severity,
        "by_category": by_category
    }