You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
119 lines
3.8 KiB
119 lines
3.8 KiB
"""
|
|
漏洞管理API路由
|
|
"""
|
|
from fastapi import APIRouter, Depends, HTTPException, Query
|
|
from sqlalchemy.orm import Session
|
|
from typing import List, Optional
|
|
from app.database import get_db
|
|
from app.models.vulnerability import Vulnerability, VulnerabilityStatus, SeverityLevel, VulnerabilityCategory
|
|
from app.schemas.vulnerability import VulnerabilityResponse, VulnerabilityUpdate
|
|
|
|
# Router collecting the vulnerability endpoints defined in this module; any URL
# prefix is applied wherever this router is included (not visible here).
router = APIRouter()
|
|
|
|
@router.get("/", response_model=List[VulnerabilityResponse])
async def get_vulnerabilities(
    scan_id: Optional[int] = None,
    project_id: Optional[int] = None,
    severity: Optional[SeverityLevel] = None,
    category: Optional[VulnerabilityCategory] = None,
    status: Optional[VulnerabilityStatus] = None,
    skip: int = 0,
    limit: int = 100,
    db: Session = Depends(get_db)
):
    """List vulnerabilities, optionally filtered, most severe first.

    Args:
        scan_id: Only return findings from this scan. Takes precedence over
            ``project_id`` when both are supplied.
        project_id: Only return findings from this project's most recently
            completed scan (same selection as the stats summary endpoint).
        severity: Exact-match filter on ``Vulnerability.severity``.
        category: Exact-match filter on ``Vulnerability.category``.
        status: Exact-match filter on ``Vulnerability.status``.
        skip: Number of rows to skip (offset pagination).
        limit: Maximum number of rows to return.
        db: Database session injected by ``get_db``.

    Returns:
        A list of ``Vulnerability`` rows ordered by severity (descending),
        then line number, then id. Returns an empty list when ``project_id``
        is given but the project has no completed scan.
    """
    query = db.query(Vulnerability)

    # Compare against None so a legitimate id/enum value is never skipped
    # because it happens to be falsy.
    if scan_id is not None:
        query = query.filter(Vulnerability.scan_id == scan_id)
    elif project_id is not None:
        # Function-scope import, matching the stats endpoint in this module.
        from app.models.scan import Scan

        latest_scan = db.query(Scan).filter(
            Scan.project_id == project_id,
            Scan.status == "completed"
        ).order_by(Scan.completed_at.desc(), Scan.id.desc()).first()

        if latest_scan is None:
            # Project has no completed scan, so there is nothing to list.
            return []
        query = query.filter(Vulnerability.scan_id == latest_scan.id)

    if severity is not None:
        query = query.filter(Vulnerability.severity == severity)
    if category is not None:
        query = query.filter(Vulnerability.category == category)
    if status is not None:
        query = query.filter(Vulnerability.status == status)

    # NOTE(review): whether severity.desc() puts the most severe first depends
    # on how the column is stored (DB enum order vs. string order) - confirm.
    # Vulnerability.id is the final key so offset pagination sees a total,
    # deterministic order (no duplicated/skipped rows across pages).
    vulnerabilities = query.order_by(
        Vulnerability.severity.desc(),
        Vulnerability.line_number,
        Vulnerability.id
    ).offset(skip).limit(limit).all()

    return vulnerabilities
|
|
|
|
@router.get("/{vulnerability_id}", response_model=VulnerabilityResponse)
async def get_vulnerability(
    vulnerability_id: int,
    db: Session = Depends(get_db)
):
    """Fetch a single vulnerability by its ``id``.

    Raises:
        HTTPException: 404 when no vulnerability has ``vulnerability_id``.
    """
    found = (
        db.query(Vulnerability)
        .filter(Vulnerability.id == vulnerability_id)
        .first()
    )
    if found:
        return found
    raise HTTPException(status_code=404, detail="漏洞不存在")
|
|
|
|
@router.put("/{vulnerability_id}", response_model=VulnerabilityResponse)
async def update_vulnerability(
    vulnerability_id: int,
    vulnerability_update: VulnerabilityUpdate,
    db: Session = Depends(get_db)
):
    """Partially update one vulnerability and return the persisted row.

    Only the fields the client explicitly sent (``exclude_unset=True``) are
    copied onto the ORM object; omitted fields keep their stored values.

    Raises:
        HTTPException: 404 when no vulnerability has ``vulnerability_id``.
    """
    record = (
        db.query(Vulnerability)
        .filter(Vulnerability.id == vulnerability_id)
        .first()
    )
    if not record:
        raise HTTPException(status_code=404, detail="漏洞不存在")

    changes = vulnerability_update.dict(exclude_unset=True)
    for attr_name, new_value in changes.items():
        setattr(record, attr_name, new_value)

    # Persist, then reload so server-side values are reflected in the response.
    db.commit()
    db.refresh(record)
    return record
|
|
|
|
@router.get("/stats/summary")
async def get_vulnerability_stats(
    scan_id: Optional[int] = None,
    project_id: Optional[int] = None,
    db: Session = Depends(get_db)
):
    """Summarize vulnerability counts by severity and by category.

    Scope is chosen as follows: ``scan_id`` if given; otherwise the most
    recently completed scan of ``project_id``; otherwise all vulnerabilities.

    Args:
        scan_id: Restrict the summary to this scan (takes precedence).
        project_id: Restrict the summary to this project's latest completed scan.
        db: Database session injected by ``get_db``.

    Returns:
        ``{"total": int, "by_severity": {str: int}, "by_category": {str: int}}``
        where keys are the enum ``.value`` of each severity/category.
        ``{"total": 0, "by_severity": {}, "by_category": {}}`` when
        ``project_id`` has no completed scan.
    """
    # Function-scope imports, matching this module's existing style for Scan.
    from sqlalchemy import func

    query = db.query(Vulnerability)

    # Compare against None so an id of 0 is not silently treated as "absent".
    if scan_id is not None:
        query = query.filter(Vulnerability.scan_id == scan_id)
    elif project_id is not None:
        from app.models.scan import Scan

        # Scan.id breaks ties on completed_at so the choice is deterministic.
        latest_scan = db.query(Scan).filter(
            Scan.project_id == project_id,
            Scan.status == "completed"
        ).order_by(Scan.completed_at.desc(), Scan.id.desc()).first()

        if latest_scan is None:
            return {"total": 0, "by_severity": {}, "by_category": {}}
        query = query.filter(Vulnerability.scan_id == latest_scan.id)

    # Let the database count with GROUP BY instead of loading every
    # Vulnerability row into memory just to tally two columns.
    by_severity = {
        level.value: count
        for level, count in query.with_entities(
            Vulnerability.severity, func.count()
        ).group_by(Vulnerability.severity)
    }
    by_category = {
        kind.value: count
        for kind, count in query.with_entities(
            Vulnerability.category, func.count()
        ).group_by(Vulnerability.category)
    }

    # Every matching row lands in exactly one severity group, so the group
    # sizes sum to the total number of matching vulnerabilities.
    total = sum(by_severity.values())

    return {
        "total": total,
        "by_severity": by_severity,
        "by_category": by_category
    }
|