""" 文件管理API路由 """ import os import mimetypes from fastapi import APIRouter, Depends, HTTPException, UploadFile, File from fastapi.responses import FileResponse, StreamingResponse from sqlalchemy.orm import Session from typing import List from app.database import get_db from app.models.project import Project import zipfile import tempfile import shutil router = APIRouter() @router.get("/projects/{project_id}/files") async def get_project_files( project_id: int, path: str = "", db: Session = Depends(get_db) ): """获取项目文件列表""" project = db.query(Project).filter(Project.id == project_id).first() if not project: raise HTTPException(status_code=404, detail="项目不存在") if not project.project_path or not os.path.exists(project.project_path): raise HTTPException(status_code=404, detail="项目路径不存在") target_path = os.path.join(project.project_path, path) if path else project.project_path if not os.path.exists(target_path): raise HTTPException(status_code=404, detail="路径不存在") if not os.path.isdir(target_path): raise HTTPException(status_code=400, detail="路径不是目录") files = [] directories = [] try: for item in os.listdir(target_path): item_path = os.path.join(target_path, item) # 跳过隐藏文件和常见的非代码文件 if item.startswith('.') or item in ['node_modules', '__pycache__', '.git', 'venv', 'env']: continue item_info = { 'name': item, 'path': os.path.join(path, item).replace('\\', '/'), 'is_directory': os.path.isdir(item_path), 'size': os.path.getsize(item_path) if os.path.isfile(item_path) else 0, 'modified': os.path.getmtime(item_path) } if item_info['is_directory']: directories.append(item_info) else: # 只显示代码文件 if _is_code_file(item): files.append(item_info) # 排序:目录在前,文件在后,按名称排序 directories.sort(key=lambda x: x['name'].lower()) files.sort(key=lambda x: x['name'].lower()) return { 'files': directories + files, 'current_path': path, 'project_path': project.project_path } except PermissionError: raise HTTPException(status_code=403, detail="权限不足") except Exception as e: raise HTTPException(status_code=500, detail=f"读取文件列表失败: {str(e)}") @router.get("/projects/{project_id}/files/content") async def get_file_content( project_id: int, file_path: str, db: Session = Depends(get_db) ): """获取文件内容""" project = db.query(Project).filter(Project.id == project_id).first() if not project: raise HTTPException(status_code=404, detail="项目不存在") full_path = os.path.join(project.project_path, file_path) if not os.path.exists(full_path) or not os.path.isfile(full_path): raise HTTPException(status_code=404, detail="文件不存在") if not _is_code_file(full_path): raise HTTPException(status_code=400, detail="不支持的文件类型") try: # 尝试不同的编码 encodings = ['utf-8', 'gbk', 'gb2312', 'latin-1'] content = None for encoding in encodings: try: with open(full_path, 'r', encoding=encoding) as f: content = f.read() break except UnicodeDecodeError: continue if content is None: raise HTTPException(status_code=400, detail="文件编码不支持") return { 'content': content, 'file_path': file_path, 'size': len(content.encode('utf-8')), 'lines': len(content.splitlines()) } except PermissionError: raise HTTPException(status_code=403, detail="权限不足") except Exception as e: raise HTTPException(status_code=500, detail=f"读取文件失败: {str(e)}") @router.post("/projects/{project_id}/files/content") async def save_file_content( project_id: int, file_path: str, content: str, db: Session = Depends(get_db) ): """保存文件内容""" project = db.query(Project).filter(Project.id == project_id).first() if not project: raise HTTPException(status_code=404, detail="项目不存在") full_path = os.path.join(project.project_path, file_path) if not os.path.exists(full_path) or not os.path.isfile(full_path): raise HTTPException(status_code=404, detail="文件不存在") try: # 备份原文件 backup_path = f"{full_path}.backup" shutil.copy2(full_path, backup_path) # 写入新内容 with open(full_path, 'w', encoding='utf-8') as f: f.write(content) return {"message": "文件保存成功", "backup": backup_path} except PermissionError: raise HTTPException(status_code=403, detail="权限不足") except Exception as e: raise HTTPException(status_code=500, detail=f"保存文件失败: {str(e)}") @router.post("/projects/{project_id}/files/upload") async def upload_files( project_id: int, files: List[UploadFile] = File(...), db: Session = Depends(get_db) ): """上传文件到项目""" project = db.query(Project).filter(Project.id == project_id).first() if not project: raise HTTPException(status_code=404, detail="项目不存在") uploaded_files = [] try: for file in files: if not _is_code_file(file.filename): continue file_path = os.path.join(project.project_path, file.filename) with open(file_path, "wb") as buffer: content = await file.read() buffer.write(content) uploaded_files.append({ 'filename': file.filename, 'size': len(content) }) return {"message": f"成功上传 {len(uploaded_files)} 个文件", "files": uploaded_files} except Exception as e: raise HTTPException(status_code=500, detail=f"上传文件失败: {str(e)}") @router.get("/projects/{project_id}/files/download") async def download_project( project_id: int, db: Session = Depends(get_db) ): """下载整个项目为ZIP文件""" project = db.query(Project).filter(Project.id == project_id).first() if not project: raise HTTPException(status_code=404, detail="项目不存在") if not os.path.exists(project.project_path): raise HTTPException(status_code=404, detail="项目路径不存在") try: # 创建临时ZIP文件 temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.zip') with zipfile.ZipFile(temp_file.name, 'w', zipfile.ZIP_DEFLATED) as zipf: for root, dirs, files in os.walk(project.project_path): # 跳过不必要的目录 dirs[:] = [d for d in dirs if not d.startswith('.') and d not in ['node_modules', '__pycache__']] for file in files: if _is_code_file(file): file_path = os.path.join(root, file) arcname = os.path.relpath(file_path, project.project_path) zipf.write(file_path, arcname) return FileResponse( temp_file.name, media_type='application/zip', filename=f"{project.name}.zip" ) except Exception as e: raise HTTPException(status_code=500, detail=f"创建下载文件失败: {str(e)}") def _is_code_file(filename: str) -> bool: """判断是否为代码文件""" code_extensions = { '.py', '.js', '.jsx', '.ts', '.tsx', '.java', '.cpp', '.c', '.h', '.hpp', '.cs', '.php', '.rb', '.go', '.rs', '.swift', '.kt', '.scala', '.sh', '.sql', '.html', '.css', '.scss', '.sass', '.less', '.vue', '.json', '.xml', '.yaml', '.yml', '.md', '.txt' } if isinstance(filename, str): _, ext = os.path.splitext(filename.lower()) return ext in code_extensions return False