# Entries hidden from listings, in addition to anything starting with a dot.
_SKIPPED_ENTRIES = frozenset({'node_modules', '__pycache__', '.git', 'venv', 'env'})

# Encodings tried in order when reading a file as text. latin-1 maps every
# byte to a code point, so the last attempt cannot raise UnicodeDecodeError.
_TEXT_ENCODINGS = ('utf-8', 'gbk', 'gb2312', 'latin-1')


def _get_project_or_404(db: Session, project_id: int) -> Project:
    """Return the project with ``project_id`` or raise a 404."""
    project = db.query(Project).filter(Project.id == project_id).first()
    if not project:
        raise HTTPException(status_code=404, detail="项目不存在")
    return project


def _require_project_root(project: Project) -> str:
    """Return the project's root directory, or raise a 404 if it is unset or missing."""
    root = project.project_path
    if not root or not os.path.exists(root):
        raise HTTPException(status_code=404, detail="项目路径不存在")
    return root


def _resolve_within_project(project_root: str, relative_path: str) -> str:
    """Join ``relative_path`` onto ``project_root`` and refuse to leave the root.

    Both paths are canonicalised with ``os.path.realpath`` so that ``..``
    segments, absolute paths and symlinks cannot be used to reach files
    outside the project directory.

    Raises:
        HTTPException: 400 when the resolved path is outside the project root.
    """
    root = os.path.realpath(project_root)
    candidate = os.path.realpath(os.path.join(root, relative_path or ""))
    try:
        contained = os.path.commonpath([root, candidate]) == root
    except ValueError:
        # Raised e.g. when the two paths live on different Windows drives.
        contained = False
    if not contained:
        raise HTTPException(status_code=400, detail="非法路径")
    return candidate


def _read_text_best_effort(path: str) -> str:
    """Read ``path`` as text, trying each encoding in ``_TEXT_ENCODINGS`` in turn."""
    for encoding in _TEXT_ENCODINGS:
        try:
            with open(path, 'r', encoding=encoding) as f:
                return f.read()
        except UnicodeDecodeError:
            continue
    # Unreachable while latin-1 is in the list; kept as a guard in case the
    # encoding list is ever shortened.
    raise HTTPException(status_code=400, detail="文件编码不支持")


@router.get("/projects/{project_id}/files")
async def get_project_files(
    project_id: int,
    path: str = "",
    db: Session = Depends(get_db)
):
    """List the sub-directories and code files of one directory of a project.

    Args:
        project_id: Id of the project whose tree is listed.
        path: Directory to list, relative to the project root ("" = root).
        db: Database session.

    Returns:
        A dict with ``files`` (directories first, then files, each sorted
        case-insensitively by name), ``current_path`` and ``project_path``.

    Raises:
        HTTPException: 404 for an unknown project or missing path, 400 when
            the path is not a directory or escapes the project root, 403 on
            permission errors, 500 on any other listing failure.
    """
    project = _get_project_or_404(db, project_id)
    project_root = _require_project_root(project)

    # Confine the client-supplied path to the project directory.
    target_path = _resolve_within_project(project_root, path)

    if not os.path.exists(target_path):
        raise HTTPException(status_code=404, detail="路径不存在")

    if not os.path.isdir(target_path):
        raise HTTPException(status_code=400, detail="路径不是目录")

    files = []
    directories = []

    try:
        for item in os.listdir(target_path):
            # Skip hidden entries and well-known dependency/cache directories.
            if item.startswith('.') or item in _SKIPPED_ENTRIES:
                continue

            item_path = os.path.join(target_path, item)
            is_directory = os.path.isdir(item_path)
            item_info = {
                'name': item,
                'path': os.path.join(path, item).replace('\\', '/'),
                'is_directory': is_directory,
                'size': os.path.getsize(item_path) if os.path.isfile(item_path) else 0,
                'modified': os.path.getmtime(item_path)
            }

            if is_directory:
                directories.append(item_info)
            elif _is_code_file(item):
                # Only code files are shown; other files are hidden.
                files.append(item_info)

        directories.sort(key=lambda x: x['name'].lower())
        files.sort(key=lambda x: x['name'].lower())

        return {
            'files': directories + files,
            'current_path': path,
            'project_path': project.project_path
        }

    except PermissionError:
        raise HTTPException(status_code=403, detail="权限不足")
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"读取文件列表失败: {str(e)}")


@router.get("/projects/{project_id}/files/content")
async def get_file_content(
    project_id: int,
    file_path: str,
    db: Session = Depends(get_db)
):
    """Return the text content of one code file of a project.

    Args:
        project_id: Id of the owning project.
        file_path: File to read, relative to the project root.
        db: Database session.

    Returns:
        A dict with ``content``, ``file_path``, ``size`` (UTF-8 byte length
        of the decoded text) and ``lines``.

    Raises:
        HTTPException: 404 for an unknown project/file, 400 for a path
            outside the project or an unsupported file type, 403 on
            permission errors, 500 on any other read failure.
    """
    project = _get_project_or_404(db, project_id)
    project_root = _require_project_root(project)
    full_path = _resolve_within_project(project_root, file_path)

    if not os.path.isfile(full_path):
        raise HTTPException(status_code=404, detail="文件不存在")

    if not _is_code_file(full_path):
        raise HTTPException(status_code=400, detail="不支持的文件类型")

    try:
        content = _read_text_best_effort(full_path)
    except HTTPException:
        # Keep deliberate HTTP errors intact instead of rewriting them to 500.
        raise
    except PermissionError:
        raise HTTPException(status_code=403, detail="权限不足")
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"读取文件失败: {str(e)}")

    return {
        'content': content,
        'file_path': file_path,
        'size': len(content.encode('utf-8')),
        'lines': len(content.splitlines())
    }


@router.post("/projects/{project_id}/files/content")
async def save_file_content(
    project_id: int,
    file_path: str,
    content: str,
    db: Session = Depends(get_db)
):
    """Overwrite an existing project file, keeping a ``.backup`` copy of the old one.

    Args:
        project_id: Id of the owning project.
        file_path: File to overwrite, relative to the project root.
        content: New file content, written as UTF-8.
        db: Database session.

    Returns:
        A dict with a success ``message`` and the ``backup`` file path.

    Raises:
        HTTPException: 404 for an unknown project/file, 400 for a path
            outside the project, 403 on permission errors, 500 otherwise.
    """
    project = _get_project_or_404(db, project_id)
    project_root = _require_project_root(project)
    # Without this confinement the endpoint could overwrite any file the
    # server process can write to.
    full_path = _resolve_within_project(project_root, file_path)

    if not os.path.isfile(full_path):
        raise HTTPException(status_code=404, detail="文件不存在")

    try:
        backup_path = f"{full_path}.backup"
        shutil.copy2(full_path, backup_path)

        with open(full_path, 'w', encoding='utf-8') as f:
            f.write(content)

        return {"message": "文件保存成功", "backup": backup_path}

    except PermissionError:
        raise HTTPException(status_code=403, detail="权限不足")
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"保存文件失败: {str(e)}")


@router.post("/projects/{project_id}/files/upload")
async def upload_files(
    project_id: int,
    files: List[UploadFile] = File(...),
    db: Session = Depends(get_db)
):
    """Store uploaded code files inside the project directory.

    Files whose name is not a recognised code file are skipped. Every
    destination is validated before anything is written, so a request that
    contains a path-escaping filename is rejected as a whole.

    Returns:
        A dict with a summary ``message`` and the list of stored ``files``.

    Raises:
        HTTPException: 404 for an unknown project, 400 when a filename
            resolves outside the project, 500 on write failures.
    """
    project = _get_project_or_404(db, project_id)
    project_root = _require_project_root(project)

    # The filename is client-controlled: resolve and confine it first.
    accepted = [
        (upload, _resolve_within_project(project_root, upload.filename))
        for upload in files
        if _is_code_file(upload.filename)
    ]

    uploaded_files = []
    try:
        for upload, destination in accepted:
            # Read before opening the target so a failed read cannot leave
            # a truncated file behind.
            content = await upload.read()
            with open(destination, "wb") as buffer:
                buffer.write(content)

            uploaded_files.append({
                'filename': upload.filename,
                'size': len(content)
            })
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"上传文件失败: {str(e)}")

    return {"message": f"成功上传 {len(uploaded_files)} 个文件", "files": uploaded_files}


@router.get("/projects/{project_id}/files/download")
async def download_project(
    project_id: int,
    db: Session = Depends(get_db)
):
    """Send the project's code files to the client as a ZIP archive.

    Hidden directories, ``node_modules`` and ``__pycache__`` are not
    descended into, and only code files are archived. The temporary archive
    is deleted after the response has been sent, or immediately if building
    it fails.

    Raises:
        HTTPException: 404 for an unknown project or missing project
            directory, 500 when the archive cannot be created.
    """
    # Imported here because this module's import block does not provide it.
    from fastapi import BackgroundTasks

    project = _get_project_or_404(db, project_id)
    project_root = _require_project_root(project)

    # mkstemp returns an open descriptor; close it right away so only
    # ZipFile holds the file while writing.
    fd, zip_path = tempfile.mkstemp(suffix='.zip')
    os.close(fd)

    try:
        with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
            for root, dirs, files in os.walk(project_root):
                # Prune in place so os.walk does not descend into them.
                dirs[:] = [
                    d for d in dirs
                    if not d.startswith('.') and d not in ('node_modules', '__pycache__')
                ]

                for file in files:
                    if _is_code_file(file):
                        file_path = os.path.join(root, file)
                        arcname = os.path.relpath(file_path, project_root)
                        zipf.write(file_path, arcname)

        # Remove the archive once the response body has been streamed.
        cleanup = BackgroundTasks()
        cleanup.add_task(os.remove, zip_path)

        return FileResponse(
            zip_path,
            media_type='application/zip',
            filename=f"{project.name}.zip",
            background=cleanup
        )

    except Exception as e:
        # Do not leave a partial archive in the temp directory.
        try:
            os.remove(zip_path)
        except OSError:
            pass
        raise HTTPException(status_code=500, detail=f"创建下载文件失败: {str(e)}")
== project_id).first() + if not project: + raise HTTPException(status_code=404, detail="项目不存在") + + if not os.path.exists(project.project_path): + raise HTTPException(status_code=404, detail="项目路径不存在") + + try: + # 创建临时ZIP文件 + temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.zip') + + with zipfile.ZipFile(temp_file.name, 'w', zipfile.ZIP_DEFLATED) as zipf: + for root, dirs, files in os.walk(project.project_path): + # 跳过不必要的目录 + dirs[:] = [d for d in dirs if not d.startswith('.') and d not in ['node_modules', '__pycache__']] + + for file in files: + if _is_code_file(file): + file_path = os.path.join(root, file) + arcname = os.path.relpath(file_path, project.project_path) + zipf.write(file_path, arcname) + + return FileResponse( + temp_file.name, + media_type='application/zip', + filename=f"{project.name}.zip" + ) + + except Exception as e: + raise HTTPException(status_code=500, detail=f"创建下载文件失败: {str(e)}") + +def _is_code_file(filename: str) -> bool: + """判断是否为代码文件""" + code_extensions = { + '.py', '.js', '.jsx', '.ts', '.tsx', '.java', '.cpp', '.c', '.h', '.hpp', + '.cs', '.php', '.rb', '.go', '.rs', '.swift', '.kt', '.scala', '.sh', + '.sql', '.html', '.css', '.scss', '.sass', '.less', '.vue', '.json', + '.xml', '.yaml', '.yml', '.md', '.txt' + } + + if isinstance(filename, str): + _, ext = os.path.splitext(filename.lower()) + return ext in code_extensions + + return False diff --git a/backend/app/api/projects.py b/backend/app/api/projects.py new file mode 100644 index 0000000..abfcce9 --- /dev/null +++ b/backend/app/api/projects.py @@ -0,0 +1,77 @@ +""" +项目管理API路由 +""" +from fastapi import APIRouter, Depends, HTTPException, status +from sqlalchemy.orm import Session +from typing import List +from app.database import get_db +from app.models.project import Project +from app.schemas.project import ProjectCreate, ProjectUpdate, ProjectResponse + +router = APIRouter() + +@router.get("/", response_model=List[ProjectResponse]) +async def get_projects( + 
def _get_project_or_404(db: Session, project_id: int) -> Project:
    """Return the project with ``project_id`` or raise a 404.

    Soft-deleted projects are still returned here, matching the behaviour
    of the detail, update and delete endpoints.
    """
    project = db.query(Project).filter(Project.id == project_id).first()
    if not project:
        raise HTTPException(status_code=404, detail="项目不存在")
    return project


@router.get("/", response_model=List[ProjectResponse])
async def get_projects(
    skip: int = 0,
    limit: int = 100,
    db: Session = Depends(get_db)
):
    """Return one page of active (not soft-deleted) projects.

    Args:
        skip: Number of rows to skip.
        limit: Maximum number of rows to return.
        db: Database session.
    """
    # .is_(True) is the SQLAlchemy idiom for a boolean test; it avoids the
    # Python-level ``== True`` comparison that linters flag.
    return (
        db.query(Project)
        .filter(Project.is_active.is_(True))
        .offset(skip)
        .limit(limit)
        .all()
    )


@router.post("/", response_model=ProjectResponse, status_code=status.HTTP_201_CREATED)
async def create_project(
    project: ProjectCreate,
    db: Session = Depends(get_db)
):
    """Create a project from the request payload and return the stored row."""
    db_project = Project(**project.dict())
    db.add(db_project)
    db.commit()
    # Reload so database-generated fields are present in the response.
    db.refresh(db_project)
    return db_project


@router.get("/{project_id}", response_model=ProjectResponse)
async def get_project(
    project_id: int,
    db: Session = Depends(get_db)
):
    """Return one project; responds 404 when the id is unknown."""
    return _get_project_or_404(db, project_id)


@router.put("/{project_id}", response_model=ProjectResponse)
async def update_project(
    project_id: int,
    project_update: ProjectUpdate,
    db: Session = Depends(get_db)
):
    """Apply a partial update to a project; responds 404 when the id is unknown.

    Only fields explicitly present in the request body are written.
    """
    project = _get_project_or_404(db, project_id)

    for field, value in project_update.dict(exclude_unset=True).items():
        setattr(project, field, value)

    db.commit()
    db.refresh(project)
    return project


@router.delete("/{project_id}")
async def delete_project(
    project_id: int,
    db: Session = Depends(get_db)
):
    """Soft-delete a project by clearing its ``is_active`` flag."""
    project = _get_project_or_404(db, project_id)
    project.is_active = False
    db.commit()
    return {"message": "项目已删除"}
@router.get("/scan/{scan_id}")
async def generate_scan_report(
    scan_id: int,
    format: str = "html",  # html, pdf, json, excel
    db: Session = Depends(get_db)
):
    """Build and return the report for one completed scan.

    Args:
        scan_id: Id of the scan to report on.
        format: One of ``html``, ``pdf``, ``json`` or ``excel``.
        db: Database session.

    Returns:
        A file download for html/pdf/excel, or the report data for json.

    Raises:
        HTTPException: 404 for an unknown scan, 400 when the scan has not
            completed or the format is not supported.
    """
    scan = db.query(Scan).filter(Scan.id == scan_id).first()
    if not scan:
        raise HTTPException(status_code=404, detail="扫描任务不存在")

    if scan.status.value != "completed":
        raise HTTPException(status_code=400, detail="扫描未完成,无法生成报告")

    report_service = ReportService()

    # JSON is returned inline; every other format is rendered to a file.
    if format == "json":
        return await report_service.generate_json_report(scan)

    # format -> (service method name, media type, download extension)
    file_reports = {
        "html": ("generate_html_report", "text/html", "html"),
        "pdf": ("generate_pdf_report", "application/pdf", "pdf"),
        "excel": (
            "generate_excel_report",
            "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
            "xlsx",
        ),
    }

    spec = file_reports.get(format)
    if spec is None:
        raise HTTPException(status_code=400, detail="不支持的报告格式")

    method_name, media_type, extension = spec
    report_path = await getattr(report_service, method_name)(scan)
    return FileResponse(
        report_path,
        media_type=media_type,
        filename=f"scan_report_{scan_id}.{extension}"
    )
@router.get("/project/{project_id}")
async def generate_project_report(
    project_id: int,
    format: str = "html",
    db: Session = Depends(get_db)
):
    """Build the summary report of a project from its latest completed scan.

    Args:
        project_id: Id of the project to report on.
        format: ``html`` (file download) or ``json`` (inline data).
        db: Database session.

    Raises:
        HTTPException: 404 for an unknown project or when it has no
            completed scan, 400 for an unsupported format.
    """
    # Imported here because this module's import block only provides Scan.
    from app.models.scan import ScanStatus

    project = db.query(Project).filter(Project.id == project_id).first()
    if not project:
        raise HTTPException(status_code=404, detail="项目不存在")

    # Compare against the enum member, as the scan endpoints do. A raw
    # "completed" string is sent to the database unchanged and does not
    # match when the column stores enum names
    # (NOTE(review): assumes Scan.status is an Enum(ScanStatus) column).
    latest_scan = db.query(Scan).filter(
        Scan.project_id == project_id,
        Scan.status == ScanStatus.COMPLETED
    ).order_by(Scan.completed_at.desc()).first()

    if not latest_scan:
        raise HTTPException(status_code=404, detail="项目没有完成的扫描记录")

    report_service = ReportService()

    if format == "html":
        report_path = await report_service.generate_project_html_report(project, latest_scan)
        return FileResponse(
            report_path,
            media_type="text/html",
            filename=f"project_report_{project_id}.html"
        )

    if format == "json":
        return await report_service.generate_project_json_report(project, latest_scan)

    raise HTTPException(status_code=400, detail="项目报告暂不支持此格式")


@router.get("/dashboard/summary")
async def get_dashboard_summary(db: Session = Depends(get_db)):
    """Return the aggregated data shown on the dashboard."""
    # Imported lazily, as in the original code.
    from app.services.dashboard_service import DashboardService

    dashboard_service = DashboardService()
    summary = await dashboard_service.get_summary_data(db)
    return summary
def _get_scan_or_404(db: Session, scan_id: int) -> Scan:
    """Return the scan with ``scan_id`` or raise a 404."""
    scan = db.query(Scan).filter(Scan.id == scan_id).first()
    if not scan:
        raise HTTPException(status_code=404, detail="扫描任务不存在")
    return scan


@router.get("/", response_model=List[ScanResponse])
async def get_scans(
    project_id: int = None,
    skip: int = 0,
    limit: int = 100,
    db: Session = Depends(get_db)
):
    """Return one page of scan history, newest first.

    Args:
        project_id: When given, only scans of this project are returned.
        skip: Number of rows to skip.
        limit: Maximum number of rows to return.
        db: Database session.
    """
    query = db.query(Scan)
    # Test against None so that a falsy id (0) still filters.
    if project_id is not None:
        query = query.filter(Scan.project_id == project_id)

    return query.order_by(Scan.created_at.desc()).offset(skip).limit(limit).all()


@router.post("/", response_model=ScanResponse, status_code=status.HTTP_201_CREATED)
async def create_scan(
    scan_data: ScanCreate,
    background_tasks: BackgroundTasks,
    db: Session = Depends(get_db)
):
    """Create a pending scan for a project and schedule it to run in the background.

    Raises:
        HTTPException: 404 when the referenced project does not exist.
    """
    project = db.query(Project).filter(Project.id == scan_data.project_id).first()
    if not project:
        raise HTTPException(status_code=404, detail="项目不存在")

    scan = Scan(
        project_id=scan_data.project_id,
        scan_type=scan_data.scan_type,
        scan_config=scan_data.scan_config,
        status=ScanStatus.PENDING
    )
    db.add(scan)
    db.commit()
    # Reload so the generated id is available for the background task.
    db.refresh(scan)

    background_tasks.add_task(ScanService.run_scan, scan.id)

    return scan


@router.get("/{scan_id}", response_model=ScanResponse)
async def get_scan(
    scan_id: int,
    db: Session = Depends(get_db)
):
    """Return one scan; responds 404 when the id is unknown."""
    return _get_scan_or_404(db, scan_id)


@router.get("/{scan_id}/status", response_model=ScanStatusResponse)
async def get_scan_status(
    scan_id: int,
    db: Session = Depends(get_db)
):
    """Return the status, progress counters and timestamps of a scan.

    The percentage is 0 while the total file count is unknown or zero.
    """
    scan = _get_scan_or_404(db, scan_id)

    # The counters may be NULL before the scan has started
    # (NOTE(review): depends on the column defaults, which are not visible
    # here). Treat NULL as 0 for the arithmetic only, instead of failing
    # with a TypeError.
    total_files = scan.total_files or 0
    scanned_files = scan.scanned_files or 0
    percentage = (scanned_files / total_files * 100) if total_files > 0 else 0

    return {
        "scan_id": scan.id,
        "status": scan.status.value,
        "progress": {
            "total_files": scan.total_files,
            "scanned_files": scan.scanned_files,
            "percentage": percentage
        },
        "started_at": scan.started_at,
        "completed_at": scan.completed_at,
        "error_message": scan.error_message
    }


@router.post("/{scan_id}/cancel")
async def cancel_scan(
    scan_id: int,
    db: Session = Depends(get_db)
):
    """Mark a scan as cancelled unless it has already finished.

    Only the stored status is changed here.

    Raises:
        HTTPException: 404 for an unknown scan, 400 when the scan is
            already completed, failed or cancelled.
    """
    scan = _get_scan_or_404(db, scan_id)

    if scan.status in (ScanStatus.COMPLETED, ScanStatus.FAILED, ScanStatus.CANCELLED):
        raise HTTPException(status_code=400, detail="扫描任务已完成或已取消")

    scan.status = ScanStatus.CANCELLED
    db.commit()

    return {"message": "扫描任务已取消"}
@router.get("/", response_model=List[VulnerabilityResponse])
async def get_vulnerabilities(
    scan_id: Optional[int] = None,
    project_id: Optional[int] = None,
    severity: Optional[SeverityLevel] = None,
    category: Optional[VulnerabilityCategory] = None,
    status: Optional[VulnerabilityStatus] = None,
    skip: int = 0,
    limit: int = 100,
    db: Session = Depends(get_db)
):
    """Return one page of vulnerabilities matching the given filters.

    All filters are optional and combined with AND. Results are ordered by
    severity (descending) and then by line number.

    Args:
        scan_id: Only vulnerabilities found by this scan.
        project_id: Only vulnerabilities found by scans of this project.
        severity: Only vulnerabilities of this severity.
        category: Only vulnerabilities of this category.
        status: Only vulnerabilities in this status.
        skip: Number of rows to skip.
        limit: Maximum number of rows to return.
        db: Database session.
    """
    query = db.query(Vulnerability)

    if scan_id is not None:
        query = query.filter(Vulnerability.scan_id == scan_id)
    if project_id is not None:
        # This parameter was previously accepted but never applied. A
        # vulnerability is linked to its project through its scan.
        from app.models.scan import Scan
        query = query.join(Scan, Vulnerability.scan_id == Scan.id).filter(
            Scan.project_id == project_id
        )
    if severity:
        query = query.filter(Vulnerability.severity == severity)
    if category:
        query = query.filter(Vulnerability.category == category)
    if status:
        query = query.filter(Vulnerability.status == status)

    return query.order_by(
        Vulnerability.severity.desc(),
        Vulnerability.line_number
    ).offset(skip).limit(limit).all()


@router.get("/{vulnerability_id}", response_model=VulnerabilityResponse)
async def get_vulnerability(
    vulnerability_id: int,
    db: Session = Depends(get_db)
):
    """Return one vulnerability; responds 404 when the id is unknown."""
    vulnerability = db.query(Vulnerability).filter(Vulnerability.id == vulnerability_id).first()
    if not vulnerability:
        raise HTTPException(status_code=404, detail="漏洞不存在")
    return vulnerability
@router.put("/{vulnerability_id}", response_model=VulnerabilityResponse)
async def update_vulnerability(
    vulnerability_id: int,
    vulnerability_update: VulnerabilityUpdate,
    db: Session = Depends(get_db)
):
    """Apply a partial update to a vulnerability; responds 404 when the id is unknown.

    Only fields explicitly present in the request body are written.
    """
    vulnerability = db.query(Vulnerability).filter(Vulnerability.id == vulnerability_id).first()
    if not vulnerability:
        raise HTTPException(status_code=404, detail="漏洞不存在")

    for field, value in vulnerability_update.dict(exclude_unset=True).items():
        setattr(vulnerability, field, value)

    db.commit()
    db.refresh(vulnerability)
    return vulnerability


@router.get("/stats/summary")
async def get_vulnerability_stats(
    scan_id: Optional[int] = None,
    project_id: Optional[int] = None,
    db: Session = Depends(get_db)
):
    """Count vulnerabilities in total, per severity and per category.

    Scope selection:
      * ``scan_id`` given: that scan only (``project_id`` is ignored).
      * only ``project_id`` given: the project's latest completed scan;
        empty statistics when the project has none.
      * neither given: all vulnerabilities.

    Returns:
        A dict with ``total``, ``by_severity`` and ``by_category``.
    """
    from collections import Counter

    query = db.query(Vulnerability)

    if scan_id:
        query = query.filter(Vulnerability.scan_id == scan_id)
    elif project_id:
        from app.models.scan import Scan, ScanStatus

        # Compare against the enum member rather than the raw string
        # "completed", as the scan endpoints do
        # (NOTE(review): assumes Scan.status is an Enum(ScanStatus) column).
        latest_scan = db.query(Scan).filter(
            Scan.project_id == project_id,
            Scan.status == ScanStatus.COMPLETED
        ).order_by(Scan.completed_at.desc()).first()

        if not latest_scan:
            return {"total": 0, "by_severity": {}, "by_category": {}}

        query = query.filter(Vulnerability.scan_id == latest_scan.id)

    vulnerabilities = query.all()

    by_severity = Counter(vuln.severity.value for vuln in vulnerabilities)
    by_category = Counter(vuln.category.value for vuln in vulnerabilities)

    return {
        "total": len(vulnerabilities),
        # Plain dicts keep the response shape identical to the original.
        "by_severity": dict(by_severity),
        "by_category": dict(by_category)
    }
class BaseAnalyzer(ABC):
    """Common base class for the language-specific static analyzers.

    Subclasses set ``supported_extensions`` (without the leading dot) and
    implement :meth:`analyze`.
    """

    def __init__(self):
        self.name = "Base Analyzer"
        self.version = "1.0.0"
        self.supported_extensions = []
        self.description = "基础分析器"

    @abstractmethod
    async def analyze(self, project_path: str, config: Dict[str, Any] = None) -> List[Dict[str, Any]]:
        """Analyze the code of a project.

        Args:
            project_path: Root directory of the project.
            config: Optional analyzer configuration.

        Returns:
            A list of findings shaped like :meth:`create_vulnerability` results.
        """
        pass

    def get_project_files(self, project_path: str) -> List[str]:
        """Return every file under ``project_path`` with a supported extension.

        The search is recursive. Only regular files are returned, each path
        appears once, and the result is sorted so that analysis order is
        deterministic.
        """
        # Escape the root so that glob metacharacters in the directory name
        # (e.g. "[" or "*") are matched literally.
        root = glob.escape(project_path)
        matches = set()
        for ext in self.supported_extensions:
            pattern = os.path.join(root, "**", f"*.{ext}")
            matches.update(
                path for path in glob.glob(pattern, recursive=True)
                # A directory named like "vendor.js" is not a source file.
                if os.path.isfile(path)
            )
        return sorted(matches)

    def read_file_content(self, file_path: str) -> str:
        """Return the text of ``file_path``, or ``""`` when it cannot be read.

        UTF-8 is tried first and GBK second. Reading is best-effort: an
        undecodable or inaccessible file yields an empty string rather than
        an exception.
        """
        for encoding in ('utf-8', 'gbk'):
            try:
                with open(file_path, 'r', encoding=encoding) as f:
                    return f.read()
            except UnicodeDecodeError:
                # Wrong encoding guess; fall through to the next candidate.
                continue
            except (OSError, ValueError):
                # Missing file, permission problem or invalid path.
                return ""
        return ""

    def create_vulnerability(
        self,
        rule_id: str,
        message: str,
        file_path: str,
        line_number: int = None,
        severity: str = "medium",
        category: str = "maintainability",
        code_snippet: str = "",
        context_before: str = "",
        context_after: str = ""
    ) -> Dict[str, Any]:
        """Build the dict that represents one finding.

        Args:
            rule_id: Identifier of the rule that fired.
            message: Description of the issue.
            file_path: File in which the issue was found.
            line_number: 1-based line of the issue, or ``None`` if unknown.
            severity: Severity label.
            category: Category label.
            code_snippet: The offending source line.
            context_before: Source text preceding the snippet.
            context_after: Source text following the snippet.

        Returns:
            A dict with exactly the keys named by the parameters above.
        """
        return {
            'rule_id': rule_id,
            'message': message,
            'file_path': file_path,
            'line_number': line_number,
            'severity': severity,
            'category': category,
            'code_snippet': code_snippet,
            'context_before': context_before,
            'context_after': context_after
        }
class CppAnalyzer(BaseAnalyzer):
    """Heuristic, line-based static analyzer for C/C++ sources.

    Every check is a regular-expression or substring heuristic applied to
    single lines; no parsing is performed, so comments and string literals
    are scanned like code.
    """

    # Rule tables: (rule id, pattern, message, severity).
    #
    # Memory heuristics overlap heavily (a declaration with malloc matches
    # both the pointer and the malloc rule), so only the first matching
    # rule is reported per line, in this priority order.
    _MEMORY_RULES = (
        ("CPP101", re.compile(r'\*[a-zA-Z_][a-zA-Z0-9_]*\s*[=;]'),
         "使用裸指针,建议使用智能指针", "medium"),
        ("CPP102", re.compile(r'(?<!\w)(?:malloc|free)\s*\('),
         "使用malloc/free,建议使用new/delete或智能指针", "medium"),
        ("CPP103", re.compile(r'\[[a-zA-Z_][a-zA-Z0-9_]*\]'),
         "数组访问未进行边界检查", "high"),
    )

    # Dangerous calls are matched as whole identifiers: a plain substring
    # test would flag the bounded snprintf()/fgets() as critical. Every
    # matching rule is reported, so one dangerous call cannot hide another
    # on the same line.
    _SECURITY_RULES = (
        ("CPP201", re.compile(r'(?<!\w)strcpy\s*\('),
         "使用strcpy(),存在缓冲区溢出风险", "critical"),
        # vsprintf is as unbounded as sprintf and was matched by the old
        # substring test, so it stays covered.
        ("CPP202", re.compile(r'(?<!\w)v?sprintf\s*\('),
         "使用sprintf(),存在缓冲区溢出风险", "critical"),
        ("CPP203", re.compile(r'(?<!\w)gets\s*\('),
         "使用gets(),存在缓冲区溢出风险", "critical"),
        ("CPP204", re.compile(r'(?<!\w)system\s*\('),
         "使用system()调用,存在命令注入风险", "high"),
    )

    # A `for (...) {` header with the opening brace on the same line.
    _FOR_LOOP = re.compile(r'for\s*\([^)]*\)\s*\{')

    # Loop bodies are inspected up to (but excluding) this offset from the
    # loop header, i.e. at most 19 lines.
    _LOOP_LOOKAHEAD = 20

    def __init__(self):
        super().__init__()
        self.name = "C++ Analyzer"
        self.version = "1.0.0"
        self.supported_extensions = ["cpp", "cc", "cxx", "hpp", "h"]
        self.description = "C++代码静态分析器"

    async def analyze(self, project_path: str, config: Dict[str, Any] = None) -> List[Dict[str, Any]]:
        """Run all checks over every supported file below ``project_path``.

        Empty or unreadable files are skipped. A failure while checking one
        file is reported as a ``CPP000`` finding for that file and does not
        stop the analysis.

        Args:
            project_path: Root directory of the project.
            config: Accepted for interface compatibility; not used here.

        Returns:
            The findings of all files, in file order.
        """
        vulnerabilities = []

        for file_path in self.get_project_files(project_path):
            try:
                content = self.read_file_content(file_path)
                if not content:
                    continue

                lines = content.split('\n')

                vulnerabilities.extend(self._check_memory_issues(lines, file_path))
                vulnerabilities.extend(self._check_security_issues(lines, file_path))
                vulnerabilities.extend(self._check_performance_issues(lines, file_path))

            except Exception as e:
                vulnerabilities.append(self.create_vulnerability(
                    rule_id="CPP000",
                    message=f"分析错误: {str(e)}",
                    file_path=file_path,
                    severity="high",
                    category="reliability"
                ))

        return vulnerabilities

    def _apply_rules(self, lines, file_path, rules, category, first_match_only):
        """Apply a rule table to each stripped line and collect the findings."""
        findings = []
        for line_num, raw_line in enumerate(lines, start=1):
            text = raw_line.strip()
            for rule_id, pattern, message, severity in rules:
                if not pattern.search(text):
                    continue
                findings.append(self.create_vulnerability(
                    rule_id=rule_id,
                    message=message,
                    file_path=file_path,
                    line_number=line_num,
                    severity=severity,
                    category=category,
                    code_snippet=text
                ))
                if first_match_only:
                    break
        return findings

    def _check_memory_issues(self, lines: List[str], file_path: str) -> List[Dict[str, Any]]:
        """Flag raw pointers, malloc/free calls and identifier-indexed array access.

        At most one finding is produced per line.
        """
        return self._apply_rules(
            lines, file_path, self._MEMORY_RULES, "reliability", first_match_only=True
        )

    def _check_security_issues(self, lines: List[str], file_path: str) -> List[Dict[str, Any]]:
        """Flag calls to strcpy, sprintf/vsprintf, gets and system.

        Each dangerous call on a line yields its own finding.
        """
        return self._apply_rules(
            lines, file_path, self._SECURITY_RULES, "security", first_match_only=False
        )

    def _check_performance_issues(self, lines: List[str], file_path: str) -> List[Dict[str, Any]]:
        """Flag likely string concatenation inside ``for`` loop bodies.

        After a loop header, the following lines are inspected until a line
        consisting only of ``}`` or the look-ahead limit is reached. A line
        is reported when it contains ``+`` together with a quote character.
        Each line is reported once, even when nested loops cover it twice.
        """
        vulnerabilities = []
        reported = set()

        for i, line in enumerate(lines):
            if not self._FOR_LOOP.search(line.strip()):
                continue

            for j in range(i + 1, min(i + self._LOOP_LOOKAHEAD, len(lines))):
                loop_line = lines[j].strip()
                if loop_line == '}':
                    break
                if j in reported:
                    continue
                if '+' in loop_line and ('"' in loop_line or "'" in loop_line):
                    reported.add(j)
                    vulnerabilities.append(self.create_vulnerability(
                        rule_id="CPP301",
                        message="循环中使用字符串连接,影响性能",
                        file_path=file_path,
                        line_number=j + 1,
                        severity="low",
                        category="performance",
                        code_snippet=loop_line
                    ))

        return vulnerabilities
class JavaScriptAnalyzer(BaseAnalyzer):
    """Heuristic, line-based static analyzer for JavaScript/TypeScript sources.

    Every check is a regular-expression or substring heuristic applied to
    single lines; no parsing is performed, so comments and string literals
    are scanned like code.
    """

    # (rule id, pattern, message, severity). Every matching rule is
    # reported, so a low-severity hit (console.log) cannot mask a
    # high-severity one (hard-coded password) on the same line.
    _SECURITY_RULES = (
        # Whole identifier only: "retrieval(" must not count as eval().
        ("JS101", re.compile(r'(?<![\w$])eval\s*\('),
         "使用eval(),存在代码注入风险", "critical"),
        # Assignments (= and +=) only; reads and ==/=== comparisons are safe.
        ("JS102", re.compile(r'\binnerHTML\s*\+?=(?!=)'),
         "使用innerHTML,存在XSS风险", "high"),
        ("JS103", re.compile(r'(?<![\w$])document\.write(?:ln)?\s*\('),
         "使用document.write(),存在XSS风险", "high"),
        ("JS104", re.compile(r'(?<![\w$])console\.log\s*\('),
         "生产环境中不应使用console.log", "low"),
        ("JS105", re.compile(r'password\s*[:=]\s*["\'][^"\']+["\']', re.IGNORECASE),
         "代码中包含硬编码的密码", "high"),
    )

    # A `for (...) {` header with the opening brace on the same line.
    _FOR_LOOP = re.compile(r'for\s*\([^)]*\)\s*\{')
    _DOM_OPERATIONS = ('getElementById', 'querySelector', 'appendChild')
    # Loop bodies are inspected up to (but excluding) this offset from the
    # loop header, i.e. at most 19 lines.
    _LOOP_LOOKAHEAD = 20

    # A declaration at the start of a line; group 1 is the first declared name.
    _DECLARATION = re.compile(r'(?:var|let|const)\s+([a-zA-Z_$][a-zA-Z0-9_$]*)')
    _ARROW_FUNCTION = re.compile(r'const\s+\w+\s*=\s*\([^)]*\)\s*=>')
    _MAGIC_NUMBER = re.compile(r'\b\d{3,}\b')

    _MAX_FUNCTION_LINES = 30
    # Leading-whitespace width above which a block counts as deeply nested
    # (10 levels at 4 spaces per level).
    _MAX_INDENT = 40

    def __init__(self):
        super().__init__()
        self.name = "JavaScript Analyzer"
        self.version = "1.0.0"
        self.supported_extensions = ["js", "jsx", "ts", "tsx"]
        self.description = "JavaScript/TypeScript代码静态分析器"

    async def analyze(self, project_path: str, config: Dict[str, Any] = None) -> List[Dict[str, Any]]:
        """Run all checks over every supported file below ``project_path``.

        Empty or unreadable files are skipped. A failure while checking one
        file is reported as a ``JS000`` finding for that file and does not
        stop the analysis.

        Args:
            project_path: Root directory of the project.
            config: Accepted for interface compatibility; not used here.

        Returns:
            The findings of all files, in file order.
        """
        vulnerabilities = []

        for file_path in self.get_project_files(project_path):
            try:
                content = self.read_file_content(file_path)
                if not content:
                    continue

                lines = content.split('\n')

                vulnerabilities.extend(self._check_security_issues(lines, file_path))
                vulnerabilities.extend(self._check_performance_issues(lines, file_path))
                vulnerabilities.extend(self._check_maintainability_issues(lines, file_path))

            except Exception as e:
                vulnerabilities.append(self.create_vulnerability(
                    rule_id="JS000",
                    message=f"分析错误: {str(e)}",
                    file_path=file_path,
                    severity="high",
                    category="reliability"
                ))

        return vulnerabilities

    def _check_security_issues(self, lines: List[str], file_path: str) -> List[Dict[str, Any]]:
        """Flag eval, innerHTML assignment, document.write, console.log and
        hard-coded passwords. Each matching rule yields its own finding."""
        vulnerabilities = []

        for line_num, line in enumerate(lines, start=1):
            line_stripped = line.strip()
            for rule_id, pattern, message, severity in self._SECURITY_RULES:
                if pattern.search(line_stripped):
                    vulnerabilities.append(self.create_vulnerability(
                        rule_id=rule_id,
                        message=message,
                        file_path=file_path,
                        line_number=line_num,
                        severity=severity,
                        category="security",
                        code_snippet=line_stripped
                    ))

        return vulnerabilities

    def _check_performance_issues(self, lines: List[str], file_path: str) -> List[Dict[str, Any]]:
        """Flag DOM access inside ``for`` loops and declared-but-unused variables.

        A variable counts as used when its name appears as a whole
        identifier later on the declaring line or on any following line.
        """
        vulnerabilities = []
        reported_dom_lines = set()

        for i, line in enumerate(lines):
            line_stripped = line.strip()

            if self._FOR_LOOP.search(line_stripped):
                for j in range(i + 1, min(i + self._LOOP_LOOKAHEAD, len(lines))):
                    loop_line = lines[j].strip()
                    if loop_line == '}':
                        break
                    # Nested loops cover the same lines twice; report once.
                    if j in reported_dom_lines:
                        continue
                    if any(dom_op in loop_line for dom_op in self._DOM_OPERATIONS):
                        reported_dom_lines.add(j)
                        vulnerabilities.append(self.create_vulnerability(
                            rule_id="JS201",
                            message="循环中进行DOM操作,影响性能",
                            file_path=file_path,
                            line_number=j + 1,
                            severity="medium",
                            category="performance",
                            code_snippet=loop_line
                        ))
                continue

            declared = self._DECLARATION.match(line_stripped)
            if declared is None:
                continue

            var_name = declared.group(1)
            # Match the name as a whole identifier: a substring test would
            # treat `i` as used by any later line containing the letter "i".
            usage = re.compile(r'(?<![\w$])' + re.escape(var_name) + r'(?![\w$])')
            rest_of_line = line_stripped[declared.end():]
            is_used = bool(usage.search(rest_of_line)) or any(
                usage.search(later) for later in lines[i + 1:]
            )

            if not is_used:
                vulnerabilities.append(self.create_vulnerability(
                    rule_id="JS202",
                    message=f"未使用的变量: {var_name}",
                    file_path=file_path,
                    line_number=i + 1,
                    severity="low",
                    category="performance",
                    code_snippet=line_stripped
                ))

        return vulnerabilities

    @staticmethod
    def _function_length(lines: List[str], start: int):
        """Return how many lines after ``start`` the function body closes.

        Braces are counted from the declaration line. The body ends on the
        line where the depth returns to zero after a brace was opened, so a
        one-line function has length 0. When no brace appears on the
        declaration line or the next one (an expression-bodied arrow
        function), the length is 1. ``None`` means the braces never balance.
        """
        depth = 0
        opened = False
        for offset, text in enumerate(lines[start:]):
            opening = text.count('{')
            depth += opening - text.count('}')
            opened = opened or opening > 0
            if opened and depth <= 0:
                return offset
            if not opened and offset >= 1:
                return offset
        return None

    def _check_maintainability_issues(self, lines: List[str], file_path: str) -> List[Dict[str, Any]]:
        """Flag over-long functions, deep nesting and magic numbers.

        At most one of the three checks is applied per line, in that order.
        """
        vulnerabilities = []

        for i, line in enumerate(lines):
            line_num = i + 1
            line_stripped = line.strip()

            if line_stripped.startswith('function ') or self._ARROW_FUNCTION.match(line_stripped):
                function_length = self._function_length(lines, i)
                if function_length is not None and function_length > self._MAX_FUNCTION_LINES:
                    vulnerabilities.append(self.create_vulnerability(
                        rule_id="JS301",
                        message="函数过长,建议拆分",
                        file_path=file_path,
                        line_number=line_num,
                        severity="medium",
                        category="maintainability",
                        code_snippet=line_stripped
                    ))

            elif line_stripped.endswith('{'):
                indent_level = len(line) - len(line.lstrip())
                if indent_level > self._MAX_INDENT:
                    vulnerabilities.append(self.create_vulnerability(
                        rule_id="JS302",
                        message="代码嵌套过深,影响可读性",
                        file_path=file_path,
                        line_number=line_num,
                        severity="low",
                        category="maintainability",
                        code_snippet=line_stripped
                    ))

            elif self._MAGIC_NUMBER.search(line_stripped):
                vulnerabilities.append(self.create_vulnerability(
                    rule_id="JS303",
                    message="存在魔法数字,建议使用常量",
                    file_path=file_path,
                    line_number=line_num,
                    severity="low",
                    category="maintainability",
                    code_snippet=line_stripped
                ))

        return vulnerabilities
python_files: + try: + # 读取文件内容 + content = self.read_file_content(file_path) + if not content: + continue + + # 解析AST + tree = ast.parse(content, filename=file_path) + + # 执行各种检查 + vulnerabilities.extend(self._check_security_issues(tree, file_path, content)) + vulnerabilities.extend(self._check_performance_issues(tree, file_path, content)) + vulnerabilities.extend(self._check_maintainability_issues(tree, file_path, content)) + + except SyntaxError as e: + # 语法错误 + vulnerabilities.append(self.create_vulnerability( + rule_id="PY001", + message=f"语法错误: {str(e)}", + file_path=file_path, + line_number=e.lineno, + severity="critical", + category="reliability" + )) + except Exception as e: + # 其他错误 + vulnerabilities.append(self.create_vulnerability( + rule_id="PY000", + message=f"分析错误: {str(e)}", + file_path=file_path, + severity="high", + category="reliability" + )) + + return vulnerabilities + + def _check_security_issues(self, tree: ast.AST, file_path: str, content: str) -> List[Dict[str, Any]]: + """检查安全问题""" + vulnerabilities = [] + lines = content.split('\n') + + for node in ast.walk(tree): + # 检查eval使用 + if isinstance(node, ast.Call) and isinstance(node.func, ast.Name) and node.func.id == 'eval': + vulnerabilities.append(self.create_vulnerability( + rule_id="PY101", + message="使用了eval()函数,存在代码注入风险", + file_path=file_path, + line_number=node.lineno, + severity="critical", + category="security", + code_snippet=lines[node.lineno - 1].strip() if node.lineno <= len(lines) else "" + )) + + # 检查exec使用 + elif isinstance(node, ast.Call) and isinstance(node.func, ast.Name) and node.func.id == 'exec': + vulnerabilities.append(self.create_vulnerability( + rule_id="PY102", + message="使用了exec()函数,存在代码注入风险", + file_path=file_path, + line_number=node.lineno, + severity="critical", + category="security", + code_snippet=lines[node.lineno - 1].strip() if node.lineno <= len(lines) else "" + )) + + # 检查pickle使用 + elif isinstance(node, ast.Call) and isinstance(node.func, 
ast.Attribute): + if (isinstance(node.func.value, ast.Name) and + node.func.value.id == 'pickle' and + node.func.attr in ['loads', 'load']): + vulnerabilities.append(self.create_vulnerability( + rule_id="PY103", + message="使用了pickle反序列化,存在安全风险", + file_path=file_path, + line_number=node.lineno, + severity="high", + category="security", + code_snippet=lines[node.lineno - 1].strip() if node.lineno <= len(lines) else "" + )) + + return vulnerabilities + + def _check_performance_issues(self, tree: ast.AST, file_path: str, content: str) -> List[Dict[str, Any]]: + """检查性能问题""" + vulnerabilities = [] + lines = content.split('\n') + + for node in ast.walk(tree): + # 检查列表推导式中的循环 + if isinstance(node, ast.ListComp): + if len(node.generators) > 1: + vulnerabilities.append(self.create_vulnerability( + rule_id="PY201", + message="复杂的列表推导式可能影响性能", + file_path=file_path, + line_number=node.lineno, + severity="medium", + category="performance", + code_snippet=lines[node.lineno - 1].strip() if node.lineno <= len(lines) else "" + )) + + # 检查全局变量使用 + elif isinstance(node, ast.Global): + vulnerabilities.append(self.create_vulnerability( + rule_id="PY202", + message="使用全局变量可能影响性能", + file_path=file_path, + line_number=node.lineno, + severity="low", + category="performance", + code_snippet=lines[node.lineno - 1].strip() if node.lineno <= len(lines) else "" + )) + + return vulnerabilities + + def _check_maintainability_issues(self, tree: ast.AST, file_path: str, content: str) -> List[Dict[str, Any]]: + """检查可维护性问题""" + vulnerabilities = [] + lines = content.split('\n') + + # 检查函数长度 + for node in ast.walk(tree): + if isinstance(node, ast.FunctionDef): + if len(node.body) > 20: + vulnerabilities.append(self.create_vulnerability( + rule_id="PY301", + message=f"函数 '{node.name}' 过长,建议拆分", + file_path=file_path, + line_number=node.lineno, + severity="medium", + category="maintainability", + code_snippet=f"def {node.name}(...):" + )) + + # 检查类长度 + elif isinstance(node, ast.ClassDef): + if 
len(node.body) > 15: + vulnerabilities.append(self.create_vulnerability( + rule_id="PY302", + message=f"类 '{node.name}' 过长,建议拆分", + file_path=file_path, + line_number=node.lineno, + severity="medium", + category="maintainability", + code_snippet=f"class {node.name}:" + )) + + # 检查循环嵌套 + elif isinstance(node, (ast.For, ast.While)): + nested_loops = 0 + for child in ast.walk(node): + if isinstance(child, (ast.For, ast.While)) and child != node: + nested_loops += 1 + + if nested_loops > 2: + vulnerabilities.append(self.create_vulnerability( + rule_id="PY303", + message="循环嵌套过深,影响代码可读性", + file_path=file_path, + line_number=node.lineno, + severity="low", + category="maintainability", + code_snippet=lines[node.lineno - 1].strip() if node.lineno <= len(lines) else "" + )) + + return vulnerabilities diff --git a/backend/app/database/__init__.py b/backend/app/database/__init__.py new file mode 100644 index 0000000..8683a7d --- /dev/null +++ b/backend/app/database/__init__.py @@ -0,0 +1,35 @@ +""" +数据库配置和连接管理 +""" +from sqlalchemy import create_engine +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import sessionmaker +import os + +# 数据库URL配置 +DATABASE_URL = os.getenv("DATABASE_URL", "sqlite:///./code_scanner.db") + +# 创建数据库引擎 +engine = create_engine( + DATABASE_URL, + connect_args={"check_same_thread": False} if "sqlite" in DATABASE_URL else {} +) + +# 创建会话工厂 +SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) + +# 创建基础模型类 +Base = declarative_base() + +def get_db(): + """获取数据库会话""" + db = SessionLocal() + try: + yield db + finally: + db.close() + +def init_db(): + """初始化数据库表""" + from app.models import project, scan, vulnerability + Base.metadata.create_all(bind=engine) diff --git a/backend/app/models/__init__.py b/backend/app/models/__init__.py new file mode 100644 index 0000000..7713b10 --- /dev/null +++ b/backend/app/models/__init__.py @@ -0,0 +1,4 @@ +# 数据模型包 +from .project import Project +from .scan import Scan 
class Project(Base):
    """ORM model for a project registered for scanning (table ``projects``)."""
    __tablename__ = "projects"

    id = Column(Integer, primary_key=True, index=True)
    name = Column(String(100), nullable=False, index=True)
    description = Column(Text)
    language = Column(String(20), nullable=False)  # e.g. Python, C++, JavaScript
    repository_url = Column(String(500))
    project_path = Column(String(500))  # local path of the project sources
    config = Column(Text)  # configuration serialized as JSON text

    # Status flag
    is_active = Column(Boolean, default=True)

    # Timestamps; updated_at is only set by the ORM on UPDATE
    created_at = Column(DateTime(timezone=True), server_default=func.now())
    updated_at = Column(DateTime(timezone=True), onupdate=func.now())

    # Scans are deleted together with their project
    scans = relationship("Scan", back_populates="project", cascade="all, delete-orphan")

    def __repr__(self):
        # The previous implementation returned an empty string, which made
        # instances invisible in logs and debugger output.
        return f"<Project(id={self.id}, name={self.name!r})>"
class Vulnerability(Base):
    """ORM model for a single finding produced by a scan (table ``vulnerabilities``)."""
    __tablename__ = "vulnerabilities"

    id = Column(Integer, primary_key=True, index=True)
    scan_id = Column(Integer, ForeignKey("scans.id"), nullable=False)

    # Basic finding information
    rule_id = Column(String(100), nullable=False)  # id of the rule that fired
    message = Column(Text, nullable=False)  # human-readable description
    category = Column(Enum(VulnerabilityCategory), nullable=False)
    severity = Column(Enum(SeverityLevel), nullable=False)

    # Location
    file_path = Column(String(500), nullable=False)
    line_number = Column(Integer)
    column_number = Column(Integer)
    end_line = Column(Integer)
    end_column = Column(Integer)

    # Code context
    code_snippet = Column(Text)  # offending code
    context_before = Column(Text)  # code preceding the snippet
    context_after = Column(Text)  # code following the snippet

    # AI enrichment
    ai_enhanced = Column(Boolean, default=False)
    ai_confidence = Column(Float)  # confidence in the range 0-1
    ai_suggestion = Column(Text)  # suggested fix

    # Workflow state
    status = Column(Enum(VulnerabilityStatus), default=VulnerabilityStatus.OPEN)
    assigned_to = Column(String(100))  # assignee
    fix_commit = Column(String(100))  # hash of the fixing commit

    # Timestamps; updated_at is only set by the ORM on UPDATE
    created_at = Column(DateTime(timezone=True), server_default=func.now())
    updated_at = Column(DateTime(timezone=True), onupdate=func.now())
    fixed_at = Column(DateTime(timezone=True))

    scan = relationship("Scan", back_populates="vulnerabilities")

    def __repr__(self):
        # The previous implementation returned an empty string, which made
        # instances invisible in logs and debugger output.
        return (
            f"<Vulnerability(id={self.id}, rule_id={self.rule_id!r}, "
            f"severity={self.severity}, file_path={self.file_path!r})>"
        )
"""创建项目模式""" + pass + +class ProjectUpdate(BaseModel): + """更新项目模式""" + name: Optional[str] = None + description: Optional[str] = None + repository_url: Optional[str] = None + project_path: Optional[str] = None + config: Optional[str] = None + +class ProjectResponse(ProjectBase): + """项目响应模式""" + id: int + is_active: bool + created_at: datetime + updated_at: Optional[datetime] = None + + class Config: + from_attributes = True diff --git a/backend/app/schemas/scan.py b/backend/app/schemas/scan.py new file mode 100644 index 0000000..9176f24 --- /dev/null +++ b/backend/app/schemas/scan.py @@ -0,0 +1,42 @@ +""" +扫描相关的Pydantic模式 +""" +from pydantic import BaseModel +from typing import Optional +from datetime import datetime +from app.models.scan import ScanStatus, ScanType + +class ScanBase(BaseModel): + """扫描基础模式""" + project_id: int + scan_type: ScanType = ScanType.FULL + scan_config: Optional[str] = None + +class ScanCreate(ScanBase): + """创建扫描模式""" + pass + +class ScanResponse(ScanBase): + """扫描响应模式""" + id: int + status: ScanStatus + total_files: int + scanned_files: int + total_vulnerabilities: int + started_at: Optional[datetime] = None + completed_at: Optional[datetime] = None + created_at: datetime + result_summary: Optional[str] = None + error_message: Optional[str] = None + + class Config: + from_attributes = True + +class ScanStatusResponse(BaseModel): + """扫描状态响应模式""" + scan_id: int + status: str + progress: dict + started_at: Optional[datetime] = None + completed_at: Optional[datetime] = None + error_message: Optional[str] = None diff --git a/backend/app/schemas/vulnerability.py b/backend/app/schemas/vulnerability.py new file mode 100644 index 0000000..4f89d6d --- /dev/null +++ b/backend/app/schemas/vulnerability.py @@ -0,0 +1,49 @@ +""" +漏洞相关的Pydantic模式 +""" +from pydantic import BaseModel +from typing import Optional +from datetime import datetime +from app.models.vulnerability import SeverityLevel, VulnerabilityCategory, VulnerabilityStatus + +class 
class AIService:
    """Client that asks the DeepSeek chat-completions API to enrich findings.

    The API key is read from the ``DEEPSEEK_API_KEY`` environment variable;
    when it is unset the original placeholder value is used, so construction
    never fails. All network work is kept off the event loop.
    """

    # Seconds to wait for the HTTP call; without a timeout ``requests`` may block forever.
    request_timeout = 60

    def __init__(self):
        import os  # local import keeps this block independent of the module header

        self.api_url = "https://api.deepseek.com/v1/chat/completions"
        self.api_key = os.getenv("DEEPSEEK_API_KEY", "your_deepseek_api_key_here")
        self.headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json"
        }

    async def enhance_vulnerability(self, vulnerability: Dict[str, Any]) -> Dict[str, Any]:
        """Return AI enrichment fields for one finding.

        Never raises: any failure (network, HTTP status, malformed reply) is
        reported through ``ai_enhanced=False`` and an explanatory
        ``ai_explanation``.

        Args:
            vulnerability: Finding dict as produced by the analyzers.

        Returns:
            Dict with ``ai_enhanced``, ``ai_confidence``, ``ai_suggestion``
            and ``ai_explanation``.
        """
        try:
            prompt = self._build_enhancement_prompt(vulnerability)
            ai_response = await self._call_ai_api(prompt)
            enhancement = self._parse_ai_response(ai_response)

            return {
                'ai_enhanced': True,
                'ai_confidence': enhancement.get('confidence', 0.8),
                'ai_suggestion': enhancement.get('suggestion', ''),
                'ai_explanation': enhancement.get('explanation', '')
            }

        except Exception as e:
            # Best effort by design: enrichment must not break the scan.
            print(f"AI增强失败: {str(e)}")
            return {
                'ai_enhanced': False,
                'ai_confidence': 0.0,
                'ai_suggestion': '',
                'ai_explanation': f'AI分析失败: {str(e)}'
            }

    def _build_enhancement_prompt(self, vulnerability: Dict[str, Any]) -> str:
        """Render the user prompt describing ``vulnerability``; missing keys become ``N/A``."""
        prompt = f"""
请分析以下代码漏洞,并提供详细的修复建议:

漏洞信息:
- 规则ID: {vulnerability.get('rule_id', 'N/A')}
- 严重程度: {vulnerability.get('severity', 'N/A')}
- 分类: {vulnerability.get('category', 'N/A')}
- 文件路径: {vulnerability.get('file_path', 'N/A')}
- 行号: {vulnerability.get('line_number', 'N/A')}
- 描述: {vulnerability.get('message', 'N/A')}

相关代码:
```{vulnerability.get('language', 'text')}
{vulnerability.get('code_snippet', '')}
```

请提供:
1. 漏洞的详细解释
2. 可能的修复方案
3. 修复后的代码示例
4. 预防类似问题的最佳实践

请以JSON格式回复,包含以下字段:
- explanation: 详细解释
- suggestion: 修复建议
- fixed_code: 修复后的代码示例
- best_practices: 最佳实践建议
- confidence: 分析置信度(0-1)
"""
        return prompt

    async def _call_ai_api(self, prompt: str) -> str:
        """Send ``prompt`` to the chat-completions endpoint and return the reply text.

        Raises:
            requests.RequestException: On connection errors, timeouts or a
                non-2xx status.
            KeyError, IndexError: If the reply lacks the expected structure.
        """
        import asyncio
        import functools

        data = {
            "model": "deepseek-chat",
            "messages": [
                {"role": "system", "content": "你是一个专业的代码安全分析专家。"},
                {"role": "user", "content": prompt}
            ],
            "temperature": 0.3,
            "max_tokens": 2000
        }

        # ``requests`` is synchronous; run it in the default executor so the
        # event loop keeps serving other coroutines while we wait.
        post = functools.partial(
            requests.post,
            self.api_url,
            headers=self.headers,
            json=data,
            timeout=self.request_timeout,
        )
        response = await asyncio.get_running_loop().run_in_executor(None, post)
        response.raise_for_status()
        result = response.json()

        return result['choices'][0]['message']['content']

    def _parse_ai_response(self, response: str) -> Dict[str, Any]:
        """Turn the model reply into a dict.

        A JSON object is returned as parsed, also when it is wrapped in a
        Markdown code fence. Anything else is returned as free text in
        ``explanation`` with a default confidence of 0.7.
        """
        text = response.strip()

        # Replies are frequently wrapped as ```json ... ```; unwrap before parsing.
        if text.startswith('```'):
            first_newline = text.find('\n')
            inner = text[first_newline + 1:] if first_newline != -1 else ''
            inner = inner.rstrip()
            if inner.endswith('```'):
                inner = inner[:-3]
            text = inner.strip()

        if text.startswith('{'):
            try:
                parsed = json.loads(text)
            except json.JSONDecodeError:
                parsed = None
            if isinstance(parsed, dict):
                return parsed

        return {
            'explanation': response,
            'suggestion': '',
            'fixed_code': '',
            'best_practices': '',
            'confidence': 0.7
        }

    async def batch_enhance_vulnerabilities(self, vulnerabilities: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Enrich findings one by one, updating each dict in place.

        Returns:
            The same dict objects, in input order, after enrichment.
        """
        import asyncio

        enhanced_vulnerabilities = []

        for vulnerability in vulnerabilities:
            enhancement = await self.enhance_vulnerability(vulnerability)
            vulnerability.update(enhancement)
            enhanced_vulnerabilities.append(vulnerability)

            # Throttle requests without blocking the event loop.
            await asyncio.sleep(0.5)

        return enhanced_vulnerabilities
app.core.analyzers.cpp_analyzer import CppAnalyzer +from app.core.analyzers.javascript_analyzer import JavaScriptAnalyzer +from app.services.ai_service import AIService + +class AnalyzerService: + """代码分析服务""" + + def __init__(self): + self.analyzers = { + 'python': PythonAnalyzer(), + 'cpp': CppAnalyzer(), + 'javascript': JavaScriptAnalyzer(), + } + self.ai_service = AIService() + + async def analyze_project(self, project_path: str, language: str, config: Dict[str, Any] = None) -> List[Dict[str, Any]]: + """分析项目代码""" + if language not in self.analyzers: + raise ValueError(f"不支持的语言: {language}") + + analyzer = self.analyzers[language] + + # 运行静态分析 + static_results = await analyzer.analyze(project_path, config) + + # AI增强分析 + ai_enhanced_results = [] + for result in static_results: + # 对每个漏洞进行AI增强 + ai_enhancement = await self.ai_service.enhance_vulnerability(result) + result.update(ai_enhancement) + ai_enhanced_results.append(result) + + return ai_enhanced_results + + def get_supported_languages(self) -> List[str]: + """获取支持的语言列表""" + return list(self.analyzers.keys()) + + def get_analyzer_info(self, language: str) -> Dict[str, Any]: + """获取分析器信息""" + if language not in self.analyzers: + return None + + analyzer = self.analyzers[language] + return { + 'name': analyzer.name, + 'version': analyzer.version, + 'supported_extensions': analyzer.supported_extensions, + 'description': analyzer.description + } diff --git a/backend/app/services/dashboard_service.py b/backend/app/services/dashboard_service.py new file mode 100644 index 0000000..8584b2d --- /dev/null +++ b/backend/app/services/dashboard_service.py @@ -0,0 +1,120 @@ +""" +仪表板服务 +""" +from sqlalchemy.orm import Session +from sqlalchemy import func, desc +from app.models.project import Project +from app.models.scan import Scan +from app.models.vulnerability import Vulnerability, VulnerabilityStatus, SeverityLevel + +class DashboardService: + """仪表板服务""" + + async def get_summary_data(self, db: Session) -> dict: + 
"""获取仪表板汇总数据""" + # 项目统计 + total_projects = db.query(Project).filter(Project.is_active == True).count() + + # 扫描统计 + total_scans = db.query(Scan).count() + completed_scans = db.query(Scan).filter(Scan.status == "completed").count() + + # 漏洞统计 + total_vulnerabilities = db.query(Vulnerability).count() + fixed_vulnerabilities = db.query(Vulnerability).filter( + Vulnerability.status == VulnerabilityStatus.FIXED + ).count() + + # 按严重程度统计 + severity_stats = db.query( + Vulnerability.severity, + func.count(Vulnerability.id).label('count') + ).group_by(Vulnerability.severity).all() + + severity_summary = {} + for severity, count in severity_stats: + severity_summary[severity.value] = count + + # 最近发现的漏洞 + recent_vulnerabilities = db.query(Vulnerability).order_by( + desc(Vulnerability.created_at) + ).limit(10).all() + + recent_vuln_data = [] + for vuln in recent_vulnerabilities: + recent_vuln_data.append({ + 'id': vuln.id, + 'project_name': vuln.scan.project.name if vuln.scan and vuln.scan.project else 'Unknown', + 'category': vuln.category.value, + 'severity': vuln.severity.value, + 'file_path': vuln.file_path, + 'message': vuln.message, + 'created_at': vuln.created_at.isoformat() + }) + + # 项目漏洞统计 + project_stats = db.query( + Project.name, + func.count(Vulnerability.id).label('vulnerability_count') + ).join(Scan, Project.id == Scan.project_id).join( + Vulnerability, Scan.id == Vulnerability.scan_id + ).group_by(Project.id, Project.name).order_by( + desc('vulnerability_count') + ).limit(5).all() + + project_vuln_data = [] + for project_name, count in project_stats: + project_vuln_data.append({ + 'project_name': project_name, + 'vulnerability_count': count + }) + + return { + 'projects': total_projects, + 'scans': total_scans, + 'completed_scans': completed_scans, + 'vulnerabilities': total_vulnerabilities, + 'fixed': fixed_vulnerabilities, + 'severity_summary': severity_summary, + 'recent_vulnerabilities': recent_vuln_data, + 'project_stats': project_vuln_data + } + + 
async def generate_html_report(self, scan: Scan) -> str:
    """Render the scan report template to an HTML file.

    Reads ``scan_report.html`` from ``self.templates_dir``, renders it with
    the data from ``self._prepare_report_data(scan)`` and writes the result
    into ``self.reports_dir`` under a name built from the scan id and the
    current timestamp.

    Args:
        scan: The scan whose findings are reported.

    Returns:
        Path of the HTML file that was written.
    """
    # Prepare the report data
    report_data = await self._prepare_report_data(scan)

    # Read the HTML template
    template_path = os.path.join(self.templates_dir, "scan_report.html")
    with open(template_path, 'r', encoding='utf-8') as f:
        template_content = f.read()

    # Render with autoescaping on.  A bare Template() does not escape, and
    # the template interpolates messages, file paths, code snippets and AI
    # suggestions that originate from the scanned project; unescaped, a
    # snippet containing "<" breaks the markup or injects HTML/script.
    template = Template(template_content, autoescape=True)
    html_content = template.render(**report_data)

    # Save the HTML file
    report_filename = f"scan_report_{scan.id}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.html"
    report_path = os.path.join(self.reports_dir, report_filename)

    with open(report_path, 'w', encoding='utf-8') as f:
        f.write(html_content)

    return report_path
async def generate_project_html_report(self, project: Project, latest_scan: Scan) -> str:
    """Render the project summary template to an HTML file.

    Reads ``project_report.html`` from ``self.templates_dir``, renders it
    with the data from ``self._prepare_project_report_data`` and writes the
    result into ``self.reports_dir`` under a name built from the project id
    and the current timestamp.

    Args:
        project: The project being summarized.
        latest_scan: The scan whose findings feed the summary.

    Returns:
        Path of the HTML file that was written.
    """
    # Prepare the project report data
    report_data = await self._prepare_project_report_data(project, latest_scan)

    # Read the project report template
    template_path = os.path.join(self.templates_dir, "project_report.html")
    with open(template_path, 'r', encoding='utf-8') as f:
        template_content = f.read()

    # Render with autoescaping on: a bare Template() does not escape, and
    # the render context carries vulnerability objects whose text comes
    # from scanned code.
    # NOTE(review): assumes project_report.html does not rely on raw HTML
    # in context values; mark such values with |safe if it does.
    template = Template(template_content, autoescape=True)
    html_content = template.render(**report_data)

    # Save the HTML file
    report_filename = f"project_report_{project.id}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.html"
    report_path = os.path.join(self.reports_dir, report_filename)

    with open(report_path, 'w', encoding='utf-8') as f:
        f.write(html_content)

    return report_path
async def _generate_stats_data(self, scan: Scan) -> list:
    """Build the rows of the statistics sheet for a scan.

    Args:
        scan: Object exposing ``vulnerabilities``; each item must provide
            ``severity.value`` and ``category.value``.

    Returns:
        A list of three-element rows: a header row, a total row, one row
        per severity and one row per category.  Severities and categories
        appear in the order they are first encountered.
    """
    # Local import: only the standard library is needed and only here.
    from collections import Counter

    vulnerabilities = scan.vulnerabilities

    # Counter keeps first-seen order, matching a hand-built counting dict.
    severity_stats = Counter(vuln.severity.value for vuln in vulnerabilities)
    category_stats = Counter(vuln.category.value for vuln in vulnerabilities)

    # NOTE(review): the header is returned as the first data row, so a
    # caller that builds a DataFrame straight from this list gets integer
    # column labels; consider passing columns=rows[0] there.
    stats_data = [
        ['统计类型', '分类', '数量'],
        ['严重程度', '总计', len(vulnerabilities)],
    ]
    stats_data.extend(
        ['严重程度', severity, count] for severity, count in severity_stats.items()
    )
    stats_data.extend(
        ['分类', category, count] for category, count in category_stats.items()
    )
    return stats_data
async def run_scan(self, scan_id: int):
    """Run the scan task ``scan_id`` and persist its outcome.

    Loads the scan, marks it RUNNING, runs the analyzer service on the
    scan's project, stores the findings, then records totals plus a JSON
    summary and marks the scan COMPLETED.  Any exception raised after the
    scan was marked RUNNING is recorded on the scan (status FAILED and the
    error text) instead of propagating.  A missing scan is reported and
    ignored.

    Args:
        scan_id: Primary key of the ``Scan`` row to execute.
    """
    # Local imports: not part of this module's top-level imports.
    from collections import Counter
    from datetime import datetime

    db = SessionLocal()
    try:
        # Load the scan task
        scan = db.query(Scan).filter(Scan.id == scan_id).first()
        if not scan:
            print(f"扫描任务不存在: {scan_id}")
            return

        # Mark the scan as running.  The event-loop clock used previously
        # is a monotonic float with an arbitrary epoch, not a timestamp.
        # NOTE(review): assumes started_at/completed_at are DateTime
        # columns like created_at elsewhere in the app - confirm on Scan.
        scan.status = ScanStatus.RUNNING
        scan.started_at = datetime.now()
        db.commit()

        try:
            # Resolve the project
            project = scan.project
            if not project:
                raise Exception("项目不存在")

            # Parse the scan configuration
            scan_config = {}
            if scan.scan_config:
                scan_config = json.loads(scan.scan_config)

            # Run the code analysis
            vulnerabilities_data = await self.analyzer_service.analyze_project(
                project_path=project.project_path,
                language=project.language,
                config=scan_config
            )

            # Persist the findings
            await self._save_vulnerabilities(db, scan, vulnerabilities_data)

            # Update the scan statistics (file count is simulated)
            scan.total_files = scan_config.get('total_files', 100)
            scan.scanned_files = scan.total_files
            scan.total_vulnerabilities = len(vulnerabilities_data)
            scan.status = ScanStatus.COMPLETED
            scan.completed_at = datetime.now()

            # Build the result summary; plain dicts keep the JSON shape.
            summary = {
                'total_vulnerabilities': scan.total_vulnerabilities,
                'by_severity': dict(Counter(
                    vuln_data.get('severity', 'medium')
                    for vuln_data in vulnerabilities_data
                )),
                'by_category': dict(Counter(
                    vuln_data.get('category', 'maintainability')
                    for vuln_data in vulnerabilities_data
                )),
            }

            scan.result_summary = json.dumps(summary)
            db.commit()

            print(f"扫描完成: {scan_id}, 发现 {scan.total_vulnerabilities} 个漏洞")

        except Exception as e:
            # Discard the partial unit of work first: after a failed
            # flush/commit the session refuses further commits until it
            # is rolled back, which would lose the FAILED status.
            db.rollback()
            scan.status = ScanStatus.FAILED
            scan.error_message = str(e)
            db.commit()
            print(f"扫描失败: {scan_id}, 错误: {str(e)}")

    finally:
        db.close()
def get_scan_progress(self, scan_id: int) -> dict:
    """Return a progress snapshot for one scan.

    Args:
        scan_id: Primary key of the ``Scan`` row to inspect.

    Returns:
        ``{'error': ...}`` when the scan does not exist; otherwise a dict
        with the scan's status value, a ``progress`` percentage computed
        from scanned/total files, the raw counters, the start/completion
        values and the stored error message.
    """
    db = SessionLocal()
    try:
        scan = db.query(Scan).filter(Scan.id == scan_id).first()
        if not scan:
            return {'error': '扫描任务不存在'}

        # The counters are only assigned when a scan finishes, so treat a
        # missing value as 0 instead of comparing None with an int.
        # NOTE(review): whether the columns default to 0 is not visible
        # here - confirm on the Scan model.
        total_files = scan.total_files or 0
        scanned_files = scan.scanned_files or 0

        progress = 0
        if total_files > 0:
            progress = (scanned_files / total_files) * 100

        return {
            'scan_id': scan_id,
            'status': scan.status.value,
            'progress': progress,
            'total_files': scan.total_files,
            'scanned_files': scan.scanned_files,
            'total_vulnerabilities': scan.total_vulnerabilities,
            'started_at': scan.started_at,
            'completed_at': scan.completed_at,
            'error_message': scan.error_message
        }
    finally:
        db.close()
+
+

代码扫描报告

+

项目: {{ project.name }} | 生成时间: {{ generated_at }}

+
+ + +
+

扫描摘要

+
+
+

总漏洞数

+
{{ total_vulnerabilities }}
+
+ {% for severity, vulns in by_severity.items() %} +
+

{{ severity|title }} 漏洞

+
{{ vulns|length }}
+
+ {% endfor %} +
+
+ + +
+

漏洞详情

+ {% for vulnerability in vulnerabilities %} +
+
+
+ {{ vulnerability.rule_id }}: {{ vulnerability.message }} +
+ + {{ vulnerability.severity.value|upper }} + +
+
+
+
+ 文件: + {{ vulnerability.file_path }} +
+
+ 行号: + {{ vulnerability.line_number or 'N/A' }} +
+
+ 分类: + {{ vulnerability.category.value }} +
+
+ 状态: + {{ vulnerability.status.value }} +
+
+ + {% if vulnerability.code_snippet %} +
+ 相关代码: +
{{ vulnerability.code_snippet }}
+
+ {% endif %} + + {% if vulnerability.ai_enhanced and vulnerability.ai_suggestion %} +
+

🤖 AI 建议

+

{{ vulnerability.ai_suggestion }}

+ {% if vulnerability.ai_confidence %} + 置信度: {{ (vulnerability.ai_confidence * 100)|round(1) }}% + {% endif %} +
+ {% endif %} +
+
+ {% endfor %} +
+ + +
class Config:
    """Application configuration read from environment variables.

    Every attribute is evaluated once, when the class body executes, and
    falls back to the default shown when the variable is unset.  Integer
    settings raise ``ValueError`` at import time if the variable holds a
    non-numeric string.
    """

    # Database
    DATABASE_URL: str = os.getenv("DATABASE_URL", "sqlite:///./code_scanner.db")

    # AI service (the key default is only a placeholder)
    DEEPSEEK_API_URL: str = os.getenv("DEEPSEEK_API_URL", "https://api.deepseek.com/v1/chat/completions")
    DEEPSEEK_API_KEY: str = os.getenv("DEEPSEEK_API_KEY", "your_deepseek_api_key_here")

    # File upload.  The default must be a digit string: int() does not
    # evaluate arithmetic, so "16 * 1024 * 1024" raised ValueError whenever
    # the variable was unset.
    UPLOAD_FOLDER: str = os.getenv("UPLOAD_FOLDER", "uploads")
    MAX_CONTENT_LENGTH: int = int(os.getenv("MAX_CONTENT_LENGTH", str(16 * 1024 * 1024)))  # 16MB

    # Scanning
    MAX_SCAN_FILES: int = int(os.getenv("MAX_SCAN_FILES", "1000"))
    SCAN_TIMEOUT: int = int(os.getenv("SCAN_TIMEOUT", "300"))  # 5 minutes

    # Reports
    REPORTS_FOLDER: str = os.getenv("REPORTS_FOLDER", "reports")

    # Logging
    LOG_LEVEL: str = os.getenv("LOG_LEVEL", "INFO")
    LOG_FILE: str = os.getenv("LOG_FILE", "app.log")
prefix="/api/projects", tags=["projects"]) +app.include_router(scans.router, prefix="/api/scans", tags=["scans"]) +app.include_router(reports.router, prefix="/api/reports", tags=["reports"]) +app.include_router(vulnerabilities.router, prefix="/api/vulnerabilities", tags=["vulnerabilities"]) +app.include_router(files.router, prefix="/api", tags=["files"]) + +@app.on_event("startup") +async def startup_event(): + """应用启动时初始化数据库""" + init_db() + +@app.get("/") +async def root(): + """根路径健康检查""" + return {"message": "代码漏洞检测系统 API", "status": "running"} + +@app.get("/health") +async def health_check(): + """健康检查接口""" + return {"status": "healthy"} + +if __name__ == "__main__": + uvicorn.run( + "main:app", + host="0.0.0.0", + port=8000, + reload=True, + log_level="info" + ) diff --git a/backend/requirements.txt b/backend/requirements.txt new file mode 100644 index 0000000..66f34ad --- /dev/null +++ b/backend/requirements.txt @@ -0,0 +1,12 @@ +fastapi==0.104.1 +uvicorn[standard]==0.24.0 +sqlalchemy==2.0.23 +pydantic==2.5.0 +python-multipart==0.0.6 +requests==2.31.0 +python-dotenv==1.0.0 +alembic==1.12.1 +pandas==2.1.4 +jinja2==3.1.2 +weasyprint==60.2 +openpyxl==3.1.2