You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
ErrorDetecting/backend/app/routers/auth.py

136 lines
6.4 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, update, func, text
from ..db import get_db
from ..models.users import User
from passlib.hash import bcrypt
from ..config import JWT_SECRET, JWT_EXPIRE_MINUTES
import jwt
from datetime import datetime, timedelta, timezone
import re
router = APIRouter()
class LoginRequest(BaseModel):
username: str
password: str
class RegisterRequest(BaseModel):
username: str
email: str
password: str
fullName: str
async def _get_user_id(db: AsyncSession, username: str) -> int | None:
res = await db.execute(text("SELECT id FROM users WHERE username=:u LIMIT 1"), {"u": username})
row = res.first()
return row[0] if row else None
async def _get_role_id(db: AsyncSession, role_key: str) -> int | None:
res = await db.execute(text("SELECT id FROM roles WHERE role_key=:k LIMIT 1"), {"k": role_key})
row = res.first()
return row[0] if row else None
async def _ensure_observer_role(db: AsyncSession) -> int:
rid = await _get_role_id(db, "observer")
if rid is not None:
return rid
await db.execute(
text(
"INSERT INTO roles(role_name, role_key, description, is_system_role, created_at, updated_at) VALUES(:rn, :rk, :desc, TRUE, NOW(), NOW())"
),
{"rn": "观察员", "rk": "observer", "desc": "系统默认观察员角色"},
)
await db.commit()
rid2 = await _get_role_id(db, "observer")
if rid2 is None:
raise HTTPException(status_code=500, detail="role_init_failed")
return rid2
async def _map_user_role(db: AsyncSession, username: str, role_key: str) -> None:
uid = await _get_user_id(db, username)
if uid is None:
raise HTTPException(status_code=500, detail="user_not_found_after_register")
rid = await _get_role_id(db, role_key)
if rid is None:
if role_key == "observer":
rid = await _ensure_observer_role(db)
else:
raise HTTPException(status_code=400, detail="role_not_exist")
await db.execute(text("DELETE FROM user_role_mapping WHERE user_id=:uid"), {"uid": uid})
await db.execute(text("INSERT INTO user_role_mapping(user_id, role_id) VALUES(:uid, :rid)"), {"uid": uid, "rid": rid})
await db.commit()
@router.post("/user/login")
async def login(req: LoginRequest, db: AsyncSession = Depends(get_db)):
demo = {"admin": "admin123", "ops": "ops123", "obs": "obs123"}
if req.username in demo and req.password == demo[req.username]:
exp = datetime.now(timezone.utc) + timedelta(minutes=JWT_EXPIRE_MINUTES)
token = jwt.encode({"sub": req.username, "exp": exp}, JWT_SECRET, algorithm="HS256")
return {"ok": True, "username": req.username, "fullName": req.username, "token": token}
try:
result = await db.execute(select(User).where(User.username == req.username).limit(1))
user = result.scalars().first()
if not user:
raise HTTPException(status_code=401, detail="invalid_credentials")
if not user.is_active:
raise HTTPException(status_code=403, detail="inactive_user")
if not bcrypt.verify(req.password, user.password_hash):
raise HTTPException(status_code=401, detail="invalid_credentials")
await db.execute(
update(User).where(User.id == user.id).values(last_login=func.now(), updated_at=func.now())
)
await db.commit()
exp = datetime.now(timezone.utc) + timedelta(minutes=JWT_EXPIRE_MINUTES)
token = jwt.encode({"sub": user.username, "exp": exp}, JWT_SECRET, algorithm="HS256")
return {"ok": True, "username": user.username, "fullName": user.full_name, "token": token}
except HTTPException:
raise
except Exception:
raise HTTPException(status_code=500, detail="server_error")
@router.post("/user/register")
async def register(req: RegisterRequest, db: AsyncSession = Depends(get_db)):
try:
errors: list[dict] = []
if not (3 <= len(req.username) <= 50) or not re.fullmatch(r"^[A-Za-z][A-Za-z0-9_]{2,49}$", req.username or ""):
errors.append({"field": "username", "code": "invalid_username", "message": "用户名需以字母开头,支持字母/数字/下划线长度3-50"})
if not re.fullmatch(r"^[^@\s]+@[^@\s]+\.[^@\s]+", req.email or ""):
errors.append({"field": "email", "code": "invalid_email", "message": "邮箱格式不正确"})
if not (8 <= len(req.password) <= 128) or not re.search(r"[A-Z]", req.password) or not re.search(r"[a-z]", req.password) or not re.search(r"\d", req.password):
errors.append({"field": "password", "code": "weak_password", "message": "密码至少8位需包含大小写字母与数字"})
if not (2 <= len(req.fullName) <= 100):
errors.append({"field": "fullName", "code": "invalid_full_name", "message": "姓名长度需在2-100之间"})
if errors:
raise HTTPException(status_code=400, detail={"errors": errors})
exists_username = await db.execute(select(User).where(User.username == req.username).limit(1))
if exists_username.scalars().first():
raise HTTPException(status_code=409, detail="user_exists")
exists_email = await db.execute(select(User.id).where(User.email == req.email).limit(1))
if exists_email.scalars().first():
raise HTTPException(status_code=409, detail="email_exists")
password_hash = bcrypt.hash(req.password)
user = User(
username=req.username,
email=req.email,
password_hash=password_hash,
full_name=req.fullName,
is_active=True,
last_login=None,
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc),
)
db.add(user)
await db.flush()
await db.commit()
await _map_user_role(db, req.username, "observer")
exp = datetime.now(timezone.utc) + timedelta(minutes=JWT_EXPIRE_MINUTES)
token = jwt.encode({"sub": user.username, "exp": exp}, JWT_SECRET, algorithm="HS256")
return {"ok": True, "username": user.username, "fullName": user.full_name, "token": token}
except HTTPException:
raise
except Exception as e:
print(f"DEBUG: Database error: {str(e)}")
raise HTTPException(status_code=500, detail=f"server_error: {str(e)}")