"""Vulnerability management API routes.

Handler docstrings contain a form feed (``\\f``): FastAPI only publishes the
text before it as the OpenAPI description, so the notes after it stay internal.
"""
from typing import List, Optional

from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy import func
from sqlalchemy.orm import Session

from app.database import get_db
from app.models.vulnerability import (
    Vulnerability,
    VulnerabilityStatus,
    SeverityLevel,
    VulnerabilityCategory,
)
from app.schemas.vulnerability import VulnerabilityResponse, VulnerabilityUpdate

router = APIRouter()

# NOTE(review): the handlers below are `async def` but issue blocking queries on a
# synchronous Session, which stalls the event loop while each query runs. Plain
# `def` handlers would let FastAPI run them in its threadpool; they are kept async
# here so any direct `await` callers keep working.


def _latest_completed_scan_id(db: Session, project_id: int) -> Optional[int]:
    """Return the id of ``project_id``'s most recently completed scan, or None."""
    # Function-level import, presumably to avoid an import cycle between model
    # modules -- confirm before hoisting it to the top of the file.
    from app.models.scan import Scan

    latest_scan = (
        db.query(Scan)
        .filter(Scan.project_id == project_id, Scan.status == "completed")
        .order_by(Scan.completed_at.desc())
        .first()
    )
    return None if latest_scan is None else latest_scan.id


def _get_vulnerability_or_404(db: Session, vulnerability_id: int) -> Vulnerability:
    """Load a vulnerability by id, raising HTTP 404 when it does not exist."""
    vulnerability = (
        db.query(Vulnerability).filter(Vulnerability.id == vulnerability_id).first()
    )
    if vulnerability is None:
        raise HTTPException(status_code=404, detail="漏洞不存在")
    return vulnerability


def _enum_value(member):
    """Return ``member.value`` for enum members; pass any other value (e.g. NULL) through."""
    return getattr(member, "value", member)


def _grouped_counts(db: Session, column, scan_id: Optional[int]) -> list:
    """Count vulnerabilities per distinct value of ``column`` using SQL GROUP BY.

    The count is restricted to ``scan_id`` when it is not None. Returns
    ``(value, count)`` rows, so no vulnerability rows are loaded into memory.
    """
    query = db.query(column, func.count(Vulnerability.id))
    if scan_id is not None:
        query = query.filter(Vulnerability.scan_id == scan_id)
    return query.group_by(column).all()


@router.get("/", response_model=List[VulnerabilityResponse])
async def get_vulnerabilities(
    scan_id: Optional[int] = None,
    project_id: Optional[int] = None,
    severity: Optional[SeverityLevel] = None,
    category: Optional[VulnerabilityCategory] = None,
    status: Optional[VulnerabilityStatus] = None,
    skip: int = 0,
    limit: int = 100,
    db: Session = Depends(get_db),
):
    """List vulnerabilities, optionally filtered.

    \f
    Args:
        scan_id: Only vulnerabilities from this scan. Takes precedence over
            ``project_id``.
        project_id: Only vulnerabilities from the project's latest completed
            scan (the same resolution ``/stats/summary`` uses). An empty list is
            returned when the project has no completed scan.
        severity: Exact-match filter on ``Vulnerability.severity``.
        category: Exact-match filter on ``Vulnerability.category``.
        status: Exact-match filter on ``Vulnerability.status``.
        skip: Number of rows to skip after ordering.
        limit: Maximum number of rows to return.
        db: Session injected via ``get_db``.

    Returns:
        Matching vulnerabilities, ordered by ``severity`` descending and then by
        ``line_number``.
    """
    query = db.query(Vulnerability)

    # Compare with None so that a legitimate value such as id 0 still filters.
    if scan_id is not None:
        query = query.filter(Vulnerability.scan_id == scan_id)
    elif project_id is not None:
        # project_id was previously accepted but ignored, which returned every
        # project's vulnerabilities. Scope it to the project's latest completed scan.
        latest_scan_id = _latest_completed_scan_id(db, project_id)
        if latest_scan_id is None:
            return []
        query = query.filter(Vulnerability.scan_id == latest_scan_id)

    if severity is not None:
        query = query.filter(Vulnerability.severity == severity)
    if category is not None:
        query = query.filter(Vulnerability.category == category)
    if status is not None:
        query = query.filter(Vulnerability.status == status)

    # NOTE(review): severity.desc() sorts by the stored column value. Whether that
    # means "most severe first" depends on how SeverityLevel is persisted (string
    # vs. native DB enum order) -- confirm.
    return (
        query.order_by(Vulnerability.severity.desc(), Vulnerability.line_number)
        .offset(skip)
        .limit(limit)
        .all()
    )


@router.get("/{vulnerability_id}", response_model=VulnerabilityResponse)
async def get_vulnerability(
    vulnerability_id: int,
    db: Session = Depends(get_db),
):
    """Get a single vulnerability.

    \f
    Raises:
        HTTPException: 404 when no vulnerability has ``vulnerability_id``.
    """
    return _get_vulnerability_or_404(db, vulnerability_id)


@router.put("/{vulnerability_id}", response_model=VulnerabilityResponse)
async def update_vulnerability(
    vulnerability_id: int,
    vulnerability_update: VulnerabilityUpdate,
    db: Session = Depends(get_db),
):
    """Update a vulnerability (e.g. its status).

    \f
    Only the fields present in the request body are written, so this acts as a
    partial update. The row is committed and refreshed before it is returned.

    Raises:
        HTTPException: 404 when no vulnerability has ``vulnerability_id``.
    """
    vulnerability = _get_vulnerability_or_404(db, vulnerability_id)

    # exclude_unset=True: fields the client did not send are left untouched.
    # NOTE(review): `.dict()` is the pydantic v1 API (a deprecated alias of
    # `model_dump` in v2) -- confirm the pydantic version before migrating.
    for field, value in vulnerability_update.dict(exclude_unset=True).items():
        setattr(vulnerability, field, value)

    db.commit()
    db.refresh(vulnerability)
    return vulnerability


@router.get("/stats/summary")
async def get_vulnerability_stats(
    scan_id: Optional[int] = None,
    project_id: Optional[int] = None,
    db: Session = Depends(get_db),
):
    """Summarize vulnerability counts by severity and by category.

    \f
    Args:
        scan_id: Restrict to this scan. Takes precedence over ``project_id``.
        project_id: Restrict to the project's latest completed scan.
        db: Session injected via ``get_db``.

    Returns:
        ``{"total": int, "by_severity": {value: count}, "by_category": {value: count}}``.
        The inner dicts are keyed by the enum ``.value``. When ``project_id`` has
        no completed scan, the total is 0 and both dicts are empty.
        With neither id given, every vulnerability is counted.
    """
    target_scan_id = scan_id
    if target_scan_id is None and project_id is not None:
        target_scan_id = _latest_completed_scan_id(db, project_id)
        if target_scan_id is None:
            return {"total": 0, "by_severity": {}, "by_category": {}}

    # Aggregate in the database instead of loading every row just to count it.
    severity_rows = _grouped_counts(db, Vulnerability.severity, target_scan_id)
    category_rows = _grouped_counts(db, Vulnerability.category, target_scan_id)

    return {
        # Every row lands in exactly one severity group (NULL forms its own
        # group), so the group sizes add up to the total row count.
        "total": sum(count for _, count in severity_rows),
        "by_severity": {_enum_value(key): count for key, count in severity_rows},
        "by_category": {_enum_value(key): count for key, count in category_rows},
    }