You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
ErrorDetecting/backend/app/routers/auth.py

213 lines
8.9 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import logging
import re
from datetime import datetime, timedelta, timezone

import jwt
from fastapi import APIRouter, Depends, HTTPException
from passlib.hash import bcrypt
from pydantic import BaseModel
from sqlalchemy import select, update, func, text
from sqlalchemy.ext.asyncio import AsyncSession

from ..config import JWT_SECRET, JWT_EXPIRE_MINUTES
from ..config import now_bj
from ..db import get_db
from ..models.users import User
# Router on which this module's /user/login and /user/register endpoints are registered.
router = APIRouter()
class LoginRequest(BaseModel):
    """Request body accepted by the login endpoint."""

    # Account name the caller wants to sign in as.
    username: str
    # Plain-text password supplied by the caller.
    password: str
class RegisterRequest(BaseModel):
    """Request body accepted by the registration endpoint."""

    # Desired account name.
    username: str
    # Contact e-mail address for the new account.
    email: str
    # Plain-text password; it is hashed before being stored.
    password: str
    # Display name of the user (camelCase to match the JSON payload).
    fullName: str
async def _get_user_id(db: AsyncSession, username: str) -> int | None:
    """Return the ``users.id`` of the account named *username*, or None if absent."""
    query = text("SELECT id FROM users WHERE username=:u LIMIT 1")
    result = await db.execute(query, {"u": username})
    match = result.first()
    if match is None:
        return None
    return match[0]
async def _get_role_id(db: AsyncSession, role_key: str) -> int | None:
    """Return the ``roles.id`` of the role identified by *role_key*, or None if absent."""
    query = text("SELECT id FROM roles WHERE role_key=:k LIMIT 1")
    result = await db.execute(query, {"k": role_key})
    match = result.first()
    if match is None:
        return None
    return match[0]
async def _ensure_observer_role(db: AsyncSession) -> int:
    """Return the id of the ``observer`` role, inserting the role first if it is missing.

    Raises:
        HTTPException: 500 ``role_init_failed`` when the role still cannot be
            found after the insert has been committed.
    """
    role_id = await _get_role_id(db, "observer")
    if role_id is None:
        insert_stmt = text(
            "INSERT INTO roles(role_name, role_key, description, is_system_role, created_at, updated_at) VALUES(:rn, :rk, :desc, TRUE, NOW(), NOW())"
        )
        insert_params = {"rn": "观察员", "rk": "observer", "desc": "系统默认观察员角色"}
        await db.execute(insert_stmt, insert_params)
        await db.commit()
        # Read the row back to obtain the id of the freshly inserted role.
        role_id = await _get_role_id(db, "observer")
        if role_id is None:
            raise HTTPException(status_code=500, detail="role_init_failed")
    return role_id
async def _map_user_role(db: AsyncSession, username: str, role_key: str) -> None:
    """Replace every role mapping of *username* with the single role *role_key*.

    Raises:
        HTTPException: 500 ``user_not_found_after_register`` when the user
            does not exist; 400 ``role_not_exist`` when the role is unknown
            and is not the ``observer`` role (which is created on demand).
    """
    user_id = await _get_user_id(db, username)
    if user_id is None:
        raise HTTPException(status_code=500, detail="user_not_found_after_register")

    role_id = await _get_role_id(db, role_key)
    if role_id is None:
        # Only the observer role may be bootstrapped automatically.
        if role_key != "observer":
            raise HTTPException(status_code=400, detail="role_not_exist")
        role_id = await _ensure_observer_role(db)

    # Drop any existing mappings first so the user ends up with exactly one role.
    delete_stmt = text("DELETE FROM user_role_mapping WHERE user_id=:uid")
    await db.execute(delete_stmt, {"uid": user_id})
    insert_stmt = text("INSERT INTO user_role_mapping(user_id, role_id) VALUES(:uid, :rid)")
    await db.execute(insert_stmt, {"uid": user_id, "rid": role_id})
    await db.commit()
async def _get_user_roles(db: AsyncSession, user_id: int) -> list[str]:
    """Return the ``role_key`` of every role mapped to the user with id *user_id*."""
    query = text("SELECT r.role_key FROM roles r JOIN user_role_mapping urm ON r.id = urm.role_id WHERE urm.user_id = :uid")
    result = await db.execute(query, {"uid": user_id})
    role_keys = []
    for record in result.all():
        role_keys.append(record[0])
    return role_keys
async def _get_role_permissions(db: AsyncSession, role_keys: list[str]) -> list[str]:
    """Return the distinct ``permission_key`` values granted to any role in *role_keys*.

    An empty *role_keys* short-circuits to an empty list without touching the database.
    """
    if not role_keys:
        return []
    query = text("""
SELECT DISTINCT p.permission_key
FROM permissions p
JOIN role_permission_mapping rpm ON p.id = rpm.permission_id
JOIN roles r ON rpm.role_id = r.id
WHERE r.role_key = ANY(:keys)
""")
    result = await db.execute(query, {"keys": role_keys})
    permission_keys = []
    for record in result.all():
        permission_keys.append(record[0])
    return permission_keys
@router.post("/user/login")
async def login(req: LoginRequest, db: AsyncSession = Depends(get_db)):
    """Authenticate a user and issue a JWT together with roles and permissions.

    Args:
        req: Login payload carrying ``username`` and ``password``.
        db: Async database session injected by FastAPI.

    Returns:
        A dict with the keys ``ok``, ``username``, ``fullName``, ``token``,
        ``roles`` and ``permissions``.

    Raises:
        HTTPException: 401 ``invalid_credentials`` for an unknown user or a
            wrong password, 403 ``inactive_user`` for a disabled account
            (only reported once the password has been verified), and
            500 ``server_error`` for any unexpected failure.
    """
    # NOTE(review): these hard-coded demo credentials bypass the users table
    # entirely. They are kept for backward compatibility, but should be removed
    # or gated behind a development-only switch before production use.
    demo = {"admin": "admin123", "ops": "ops123", "obs": "obs123"}
    if req.username in demo and req.password == demo[req.username]:
        # PyJWT interprets a naive datetime as UTC, so the expiry is built from
        # an explicitly UTC-aware clock rather than a local-time helper.
        exp = datetime.now(timezone.utc) + timedelta(minutes=JWT_EXPIRE_MINUTES)
        token = jwt.encode({"sub": req.username, "exp": exp}, JWT_SECRET, algorithm="HS256")
        # Resolve roles and permissions for the demo account.
        uid = await _get_user_id(db, req.username)
        roles = await _get_user_roles(db, uid) if uid is not None else []
        if not roles:
            # No mapping stored in the database: fall back to a default role.
            role_map = {"admin": ["admin"], "ops": ["operator"], "obs": ["observer"]}
            roles = role_map.get(req.username, [])
        permissions = await _get_role_permissions(db, roles)
        return {
            "ok": True,
            "username": req.username,
            "fullName": req.username,
            "token": token,
            "roles": roles,
            "permissions": permissions
        }
    try:
        result = await db.execute(select(User).where(User.username == req.username).limit(1))
        user = result.scalars().first()
        # Verify the password before looking at the account state so that an
        # unauthenticated caller cannot learn whether an account exists or is
        # disabled.
        if not user or not bcrypt.verify(req.password, user.password_hash):
            raise HTTPException(status_code=401, detail="invalid_credentials")
        if not user.is_active:
            raise HTTPException(status_code=403, detail="inactive_user")
        # Snapshot the values needed later: ORM attributes may be expired by
        # commit(), and reloading them implicitly is not possible on an async
        # session.
        user_id = user.id
        username = user.username
        full_name = user.full_name
        await db.execute(
            update(User).where(User.id == user_id).values(last_login=func.now(), updated_at=func.now())
        )
        await db.commit()
        # Resolve the user's roles and permissions.
        roles = await _get_user_roles(db, user_id)
        permissions = await _get_role_permissions(db, roles)
        exp = datetime.now(timezone.utc) + timedelta(minutes=JWT_EXPIRE_MINUTES)
        token = jwt.encode({"sub": username, "exp": exp}, JWT_SECRET, algorithm="HS256")
        return {
            "ok": True,
            "username": username,
            "fullName": full_name,
            "token": token,
            "roles": roles,
            "permissions": permissions
        }
    except HTTPException:
        raise
    except Exception as exc:
        # Keep the details in the server log; the client only sees a generic code.
        logging.getLogger(__name__).exception("login failed for user %r", req.username)
        raise HTTPException(status_code=500, detail="server_error") from exc
def _validate_registration(req: RegisterRequest) -> list[dict]:
    """Return the field-level validation errors for *req* (empty when valid)."""
    errors: list[dict] = []
    # Username: 3-50 characters, starts with a letter, letters/digits/underscore only.
    if not req.username or not (3 <= len(req.username) <= 50):
        errors.append({"field": "username", "code": "invalid_username", "message": "用户名长度需在3-50之间"})
    elif not re.fullmatch(r"^[A-Za-z][A-Za-z0-9_]*$", req.username):
        errors.append({"field": "username", "code": "invalid_username", "message": "用户名需以字母开头,仅支持字母、数字和下划线"})
    # E-mail: minimal "local@domain.tld" shape check.
    if not req.email or not re.fullmatch(r"^[^@\s]+@[^@\s]+\.[^@\s]+", req.email):
        errors.append({"field": "email", "code": "invalid_email", "message": "邮箱格式不正确"})
    # Password: the frontend asks for >= 6 characters while the backend requires
    # >= 8 with mixed case and a digit; both cases are rejected, only the
    # message differs.
    if not req.password or len(req.password) < 6:
        errors.append({"field": "password", "code": "weak_password", "message": "密码长度至少为6位"})
    elif len(req.password) < 8 or not re.search(r"[A-Z]", req.password) or not re.search(r"[a-z]", req.password) or not re.search(r"\d", req.password):
        errors.append({"field": "password", "code": "weak_password", "message": "密码建议至少8位且包含大小写字母与数字"})
    # Full name: 2-100 characters.
    if not req.fullName or not (2 <= len(req.fullName) <= 100):
        errors.append({"field": "fullName", "code": "invalid_full_name", "message": "姓名长度需在2-100之间"})
    return errors


@router.post("/user/register")
async def register(req: RegisterRequest, db: AsyncSession = Depends(get_db)):
    """Create a new account with the ``observer`` role and return a JWT for it.

    Args:
        req: Registration payload (``username``, ``email``, ``password``, ``fullName``).
        db: Async database session injected by FastAPI.

    Returns:
        A dict with the keys ``ok``, ``username``, ``fullName``, ``token``,
        ``roles`` and ``permissions``.

    Raises:
        HTTPException: 400 with an ``errors`` list when validation fails,
            400 ``user_exists`` / ``email_exists`` when the username or e-mail
            is already taken, and 500 ``server_error`` for any unexpected
            failure.
    """
    try:
        errors = _validate_registration(req)
        if errors:
            raise HTTPException(status_code=400, detail={"errors": errors, "message": errors[0]["message"]})
        # Uniqueness checks.
        exists_username = await db.execute(select(User.id).where(User.username == req.username).limit(1))
        if exists_username.scalars().first() is not None:
            raise HTTPException(status_code=400, detail={"message": "该用户名已被注册", "code": "user_exists"})
        exists_email = await db.execute(select(User.id).where(User.email == req.email).limit(1))
        if exists_email.scalars().first() is not None:
            raise HTTPException(status_code=400, detail={"message": "该邮箱已被绑定", "code": "email_exists"})
        # Read the clock once so created_at and updated_at are identical.
        created = now_bj()
        user = User(
            username=req.username,
            email=req.email,
            password_hash=bcrypt.hash(req.password),
            full_name=req.fullName,
            is_active=True,
            last_login=None,
            created_at=created,
            updated_at=created,
        )
        db.add(user)
        await db.commit()
        await _map_user_role(db, req.username, "observer")
        permissions = await _get_role_permissions(db, ["observer"])
        # PyJWT interprets a naive datetime as UTC, so the expiry is built from
        # an explicitly UTC-aware clock rather than a local-time helper.
        exp = datetime.now(timezone.utc) + timedelta(minutes=JWT_EXPIRE_MINUTES)
        # Use the request values rather than the ORM instance: its attributes
        # may be expired by commit() and cannot be reloaded implicitly on an
        # async session.
        token = jwt.encode({"sub": req.username, "exp": exp}, JWT_SECRET, algorithm="HS256")
        return {
            "ok": True,
            "username": req.username,
            "fullName": req.fullName,
            "token": token,
            "roles": ["observer"],
            "permissions": permissions
        }
    except HTTPException:
        raise
    except Exception as exc:
        # Log the details server-side; never echo internal error text to the client.
        logging.getLogger(__name__).exception("registration failed for user %r", req.username)
        raise HTTPException(status_code=500, detail="server_error") from exc