You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

246 lines
8.5 KiB

"""
文件管理API路由
"""
import contextlib
import mimetypes
import os
import shutil
import tempfile
import zipfile
from typing import List

from fastapi import APIRouter, BackgroundTasks, Depends, File, HTTPException, UploadFile
from fastapi.responses import FileResponse, StreamingResponse
from sqlalchemy.orm import Session

from app.database import get_db
from app.models.project import Project
# Router that owns every file-management endpoint declared in this module.
# NOTE(review): presumably mounted with include_router() elsewhere — not visible here.
router: APIRouter = APIRouter()
@router.get("/projects/{project_id}/files")
async def get_project_files(
    project_id: int,
    path: str = "",
    db: Session = Depends(get_db)
):
    """List the entries of one directory inside a project.

    Hidden entries (names starting with ".") and the directories
    node_modules, __pycache__, venv and env are skipped. Regular files are
    only included when ``_is_code_file`` accepts their name. Directories are
    returned first, then files, each group sorted case-insensitively by name.

    Args:
        project_id: Primary key of the ``Project`` row.
        path: Directory relative to the project root; "" means the root.
        db: SQLAlchemy session injected by FastAPI.

    Returns:
        dict with ``files`` (entry dicts: name, path, is_directory, size,
        modified), ``current_path`` (the ``path`` argument) and
        ``project_path``.

    Raises:
        HTTPException: 404 if the project, its root or ``path`` does not
            exist; 400 if ``path`` escapes the project root or is not a
            directory; 403 on a permission error while listing; 500 on any
            other listing failure.
    """
    project = db.query(Project).filter(Project.id == project_id).first()
    if not project:
        raise HTTPException(status_code=404, detail="项目不存在")
    if not project.project_path or not os.path.exists(project.project_path):
        raise HTTPException(status_code=404, detail="项目路径不存在")

    # Resolve ".." and symlinks before the containment check so a
    # client-supplied "../../etc" or absolute path cannot escape the project.
    root = os.path.realpath(project.project_path)
    try:
        target_path = os.path.realpath(os.path.join(root, path))
        inside = os.path.commonpath([root, target_path]) == root
    except ValueError:  # embedded NUL byte, or different drives on Windows
        inside = False
    if not inside:
        raise HTTPException(status_code=400, detail="非法路径")
    if not os.path.exists(target_path):
        raise HTTPException(status_code=404, detail="路径不存在")
    if not os.path.isdir(target_path):
        raise HTTPException(status_code=400, detail="路径不是目录")

    directories = []
    files = []
    try:
        for item in os.listdir(target_path):
            # Skip hidden entries (this also covers .git) and common
            # dependency/cache directories.
            if item.startswith('.') or item in {'node_modules', '__pycache__', 'venv', 'env'}:
                continue
            item_path = os.path.join(target_path, item)
            is_directory = os.path.isdir(item_path)
            # Only code files are shown; filter before stat-ing the entry.
            if not is_directory and not _is_code_file(item):
                continue
            try:
                item_info = {
                    'name': item,
                    'path': os.path.join(path, item).replace('\\', '/'),
                    'is_directory': is_directory,
                    'size': os.path.getsize(item_path) if os.path.isfile(item_path) else 0,
                    'modified': os.path.getmtime(item_path),
                }
            except FileNotFoundError:
                # Dangling symlink, or the entry vanished after listdir():
                # skip it instead of failing the whole listing.
                continue
            (directories if is_directory else files).append(item_info)
    except PermissionError as e:
        raise HTTPException(status_code=403, detail="权限不足") from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"读取文件列表失败: {str(e)}") from e

    # Directories first, then files; each group sorted by name, ignoring case.
    directories.sort(key=lambda x: x['name'].lower())
    files.sort(key=lambda x: x['name'].lower())
    return {
        'files': directories + files,
        'current_path': path,
        'project_path': project.project_path
    }
@router.get("/projects/{project_id}/files/content")
async def get_file_content(
    project_id: int,
    file_path: str,
    db: Session = Depends(get_db)
):
    """Return the decoded text of one code file inside a project.

    The file is decoded with the first encoding that succeeds, in the order
    utf-8, gbk, gb2312, latin-1. It is read in text mode, so line endings are
    normalised to "\\n".

    Args:
        project_id: Primary key of the ``Project`` row.
        file_path: File path relative to the project root.
        db: SQLAlchemy session injected by FastAPI.

    Returns:
        dict with ``content``, ``file_path`` (as given), ``size`` (length of
        the content re-encoded as UTF-8, in bytes) and ``lines`` (number of
        lines).

    Raises:
        HTTPException: 404 if the project, its root or the file does not
            exist; 400 if ``file_path`` escapes the project root, is not a
            code file or cannot be decoded; 403 on a permission error;
            500 on any other read failure.
    """
    project = db.query(Project).filter(Project.id == project_id).first()
    if not project:
        raise HTTPException(status_code=404, detail="项目不存在")
    if not project.project_path:
        raise HTTPException(status_code=404, detail="项目路径不存在")

    # Resolve ".." and symlinks, then make sure the target stays inside the
    # project. Without this, an arbitrary server file could be read.
    root = os.path.realpath(project.project_path)
    try:
        full_path = os.path.realpath(os.path.join(root, file_path))
        inside = os.path.commonpath([root, full_path]) == root
    except ValueError:  # embedded NUL byte, or different drives on Windows
        inside = False
    if not inside:
        raise HTTPException(status_code=400, detail="非法路径")
    if not os.path.isfile(full_path):
        raise HTTPException(status_code=404, detail="文件不存在")
    if not _is_code_file(full_path):
        raise HTTPException(status_code=400, detail="不支持的文件类型")

    content = None
    try:
        # Try each encoding in turn; the first one that decodes wins.
        for encoding in ('utf-8', 'gbk', 'gb2312', 'latin-1'):
            try:
                with open(full_path, 'r', encoding=encoding) as f:
                    content = f.read()
                break
            except UnicodeDecodeError:
                continue
    except PermissionError as e:
        raise HTTPException(status_code=403, detail="权限不足") from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"读取文件失败: {str(e)}") from e

    # Checked outside the try so the 400 is not rewrapped as a 500 by the
    # generic handler. latin-1 accepts any byte, so this is purely defensive.
    if content is None:
        raise HTTPException(status_code=400, detail="文件编码不支持")
    return {
        'content': content,
        'file_path': file_path,
        'size': len(content.encode('utf-8')),
        'lines': len(content.splitlines())
    }
@router.post("/projects/{project_id}/files/content")
async def save_file_content(
    project_id: int,
    file_path: str,
    content: str,
    db: Session = Depends(get_db)
):
    """Overwrite an existing code file in a project, keeping one backup.

    Before writing, the current file is copied with ``shutil.copy2`` to
    ``<file>.backup``, replacing any earlier backup. The new content is
    always written as UTF-8, whatever the original file's encoding was.

    NOTE(review): ``file_path`` and ``content`` are plain ``str`` parameters,
    so FastAPI takes them from the query string. Large files may hit URL
    length limits; a request body would be more suitable.

    Args:
        project_id: Primary key of the ``Project`` row.
        file_path: Existing file path relative to the project root.
        content: New full text of the file.
        db: SQLAlchemy session injected by FastAPI.

    Returns:
        dict with ``message`` and the server-side ``backup`` path.

    Raises:
        HTTPException: 404 if the project, its root or the file does not
            exist; 400 if ``file_path`` escapes the project root or is not a
            code file; 403 on a permission error; 500 on any other failure.
    """
    project = db.query(Project).filter(Project.id == project_id).first()
    if not project:
        raise HTTPException(status_code=404, detail="项目不存在")
    if not project.project_path:
        raise HTTPException(status_code=404, detail="项目路径不存在")

    # Resolve ".." and symlinks, then confine the write to the project.
    # Without this check, any file writable by the server could be overwritten.
    root = os.path.realpath(project.project_path)
    try:
        full_path = os.path.realpath(os.path.join(root, file_path))
        inside = os.path.commonpath([root, full_path]) == root
    except ValueError:  # embedded NUL byte, or different drives on Windows
        inside = False
    if not inside:
        raise HTTPException(status_code=400, detail="非法路径")
    if not os.path.isfile(full_path):
        raise HTTPException(status_code=404, detail="文件不存在")
    # Same restriction as the read endpoint: only code files are editable.
    if not _is_code_file(full_path):
        raise HTTPException(status_code=400, detail="不支持的文件类型")

    backup_path = f"{full_path}.backup"
    try:
        # Keep a copy of the previous version (with metadata) before writing.
        shutil.copy2(full_path, backup_path)
        with open(full_path, 'w', encoding='utf-8') as f:
            f.write(content)
    except PermissionError as e:
        raise HTTPException(status_code=403, detail="权限不足") from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"保存文件失败: {str(e)}") from e
    return {"message": "文件保存成功", "backup": backup_path}
@router.post("/projects/{project_id}/files/upload")
async def upload_files(
    project_id: int,
    files: List[UploadFile] = File(...),
    db: Session = Depends(get_db)
):
    """Write uploaded code files into the project, overwriting existing ones.

    Uploads whose filename is not accepted by ``_is_code_file`` are silently
    skipped. Every destination is validated before anything is written, so
    a rejected name leaves the project untouched.

    Args:
        project_id: Primary key of the ``Project`` row.
        files: Multipart uploads; each ``filename`` is taken as a path
            relative to the project root.
        db: SQLAlchemy session injected by FastAPI.

    Returns:
        dict with a ``message`` and ``files`` (filename and byte size of each
        file written).

    Raises:
        HTTPException: 404 if the project or its root does not exist; 400 if
            any filename escapes the project root; 500 if writing fails.
    """
    project = db.query(Project).filter(Project.id == project_id).first()
    if not project:
        raise HTTPException(status_code=404, detail="项目不存在")
    if not project.project_path or not os.path.exists(project.project_path):
        raise HTTPException(status_code=404, detail="项目路径不存在")

    # The filename is client-controlled: resolve it and refuse anything that
    # lands outside the project (e.g. "../../x.py" or "/etc/x.py").
    root = os.path.realpath(project.project_path)
    targets = []
    for file in files:
        if not _is_code_file(file.filename):
            continue
        try:
            destination = os.path.realpath(os.path.join(root, file.filename))
            inside = os.path.commonpath([root, destination]) == root
        except ValueError:  # embedded NUL byte, or different drives on Windows
            inside = False
        if not inside:
            raise HTTPException(status_code=400, detail="非法路径")
        targets.append((file, destination))

    uploaded_files = []
    try:
        for file, destination in targets:
            content = await file.read()
            with open(destination, "wb") as buffer:
                buffer.write(content)
            uploaded_files.append({
                'filename': file.filename,
                'size': len(content)
            })
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"上传文件失败: {str(e)}") from e
    return {"message": f"成功上传 {len(uploaded_files)} 个文件", "files": uploaded_files}
@router.get("/projects/{project_id}/files/download")
async def download_project(
    project_id: int,
    db: Session = Depends(get_db)
):
    """Stream the project's code files as a ZIP archive.

    Directories whose name starts with "." and the node_modules and
    __pycache__ directories are skipped. Only files accepted by
    ``_is_code_file`` are archived, stored under their path relative to the
    project root. The archive is built in a temporary file, which is
    removed after the response has been sent, or immediately if building
    it fails.

    Args:
        project_id: Primary key of the ``Project`` row.
        db: SQLAlchemy session injected by FastAPI.

    Returns:
        FileResponse serving ``<project name>.zip``.

    Raises:
        HTTPException: 404 if the project or its root does not exist;
            500 if the archive cannot be created.
    """
    project = db.query(Project).filter(Project.id == project_id).first()
    if not project:
        raise HTTPException(status_code=404, detail="项目不存在")
    if not project.project_path or not os.path.exists(project.project_path):
        raise HTTPException(status_code=404, detail="项目路径不存在")

    zip_path = None
    try:
        # mkstemp returns an open descriptor; close it right away because
        # ZipFile reopens the file by name (the old code leaked this handle).
        fd, zip_path = tempfile.mkstemp(suffix='.zip')
        os.close(fd)
        with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
            for root, dirs, files in os.walk(project.project_path):
                # Prune in place so os.walk does not descend into skipped dirs.
                dirs[:] = [d for d in dirs if not d.startswith('.') and d not in ('node_modules', '__pycache__')]
                for name in files:
                    if _is_code_file(name):
                        file_path = os.path.join(root, name)
                        zipf.write(file_path, os.path.relpath(file_path, project.project_path))
    except Exception as e:
        # Best-effort removal of the partial archive; report the original error.
        if zip_path is not None:
            with contextlib.suppress(OSError):
                os.remove(zip_path)
        raise HTTPException(status_code=500, detail=f"创建下载文件失败: {str(e)}") from e

    # Delete the temporary archive once the response body has been sent.
    cleanup = BackgroundTasks()
    cleanup.add_task(os.remove, zip_path)
    return FileResponse(
        zip_path,
        media_type='application/zip',
        filename=f"{project.name}.zip",
        background=cleanup
    )
def _is_code_file(filename: str) -> bool:
"""判断是否为代码文件"""
code_extensions = {
'.py', '.js', '.jsx', '.ts', '.tsx', '.java', '.cpp', '.c', '.h', '.hpp',
'.cs', '.php', '.rb', '.go', '.rs', '.swift', '.kt', '.scala', '.sh',
'.sql', '.html', '.css', '.scss', '.sass', '.less', '.vue', '.json',
'.xml', '.yaml', '.yml', '.md', '.txt'
}
if isinstance(filename, str):
_, ext = os.path.splitext(filename.lower())
return ext in code_extensions
return False