|
|
from fastapi import APIRouter, Depends, HTTPException
|
|
|
from pydantic import BaseModel
|
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
from sqlalchemy import select, update, func, text
|
|
|
from ..db import get_db
|
|
|
from ..models.users import User
|
|
|
from passlib.hash import bcrypt
|
|
|
from ..config import JWT_SECRET, JWT_EXPIRE_MINUTES
|
|
|
import jwt
|
|
|
from datetime import datetime, timedelta, timezone
|
|
|
import re
|
|
|
from ..config import now_bj
|
|
|
|
|
|
# Router that the endpoints in this module attach to via ``@router.post``
# (``/user/login`` and ``/user/register``). Presumably included into the
# FastAPI app elsewhere — not visible here.
router = APIRouter()
|
|
|
|
|
|
# Request body for ``POST /user/login``.
class LoginRequest(BaseModel):
    # Account name; matched against the demo table first, then ``users.username``.
    username: str
    # Plaintext password; compared to the demo table or verified against the
    # stored bcrypt hash.
    password: str
|
|
|
|
|
|
# Request body for ``POST /user/register``. Field-level validation
# (length/format rules) is performed in the ``register`` handler, not here.
class RegisterRequest(BaseModel):
    # 3-50 chars, must start with a letter; letters/digits/underscore only.
    username: str
    # Checked with a loose ``local@domain.tld`` regex in the handler.
    email: str
    # Plaintext password; stored only as a bcrypt hash.
    password: str
    # camelCase to match the JSON key sent by the client; stored as
    # ``User.full_name``.
    fullName: str
|
|
|
|
|
|
async def _get_user_id(db: AsyncSession, username: str) -> int | None:
    """Look up the ``users.id`` of the account named *username*.

    Args:
        db: Async database session used for the query.
        username: Exact username to match.

    Returns:
        The user's id, or ``None`` when no row matches.
    """
    lookup = text("SELECT id FROM users WHERE username=:u LIMIT 1")
    first_row = (await db.execute(lookup, {"u": username})).first()
    if first_row is None:
        return None
    return first_row[0]
|
|
|
|
|
|
async def _get_role_id(db: AsyncSession, role_key: str) -> int | None:
    """Look up the ``roles.id`` whose ``role_key`` equals *role_key*.

    Args:
        db: Async database session used for the query.
        role_key: Machine-readable role identifier (e.g. ``"observer"``).

    Returns:
        The role's id, or ``None`` when no such role exists.
    """
    lookup = text("SELECT id FROM roles WHERE role_key=:k LIMIT 1")
    first_row = (await db.execute(lookup, {"k": role_key})).first()
    if first_row is None:
        return None
    return first_row[0]
|
|
|
|
|
|
async def _ensure_observer_role(db: AsyncSession) -> int:
    """Return the id of the ``observer`` role, creating the role if missing.

    When the role is absent it is inserted as a system role (with a Chinese
    display name and description), the transaction is committed, and the
    row is read back.

    Args:
        db: Async database session; may be committed by this call.

    Returns:
        The ``roles.id`` of the observer role.

    Raises:
        HTTPException: 500 ``role_init_failed`` if the role still cannot be
            found after the insert.
    """
    existing_id = await _get_role_id(db, "observer")
    if existing_id is not None:
        return existing_id

    # NOTE(review): check-then-insert is not atomic; two concurrent callers
    # could both insert unless ``roles.role_key`` has a unique constraint —
    # confirm against the schema.
    insert_stmt = text(
        "INSERT INTO roles(role_name, role_key, description, is_system_role, created_at, updated_at) VALUES(:rn, :rk, :desc, TRUE, NOW(), NOW())"
    )
    insert_params = {"rn": "观察员", "rk": "observer", "desc": "系统默认观察员角色"}
    await db.execute(insert_stmt, insert_params)
    await db.commit()

    created_id = await _get_role_id(db, "observer")
    if created_id is None:
        raise HTTPException(status_code=500, detail="role_init_failed")
    return created_id
|
|
|
|
|
|
async def _map_user_role(db: AsyncSession, username: str, role_key: str) -> None:
    """Make *role_key* the only role mapped to the user named *username*.

    Existing ``user_role_mapping`` rows for the user are deleted before the
    new mapping is inserted, and the transaction is committed. The
    ``observer`` role is created on demand; any other missing role is an
    error.

    Args:
        db: Async database session; committed by this call.
        username: Username of an existing account.
        role_key: Key of the role to assign.

    Raises:
        HTTPException: 500 ``user_not_found_after_register`` if the user does
            not exist; 400 ``role_not_exist`` if a non-observer role is
            missing; 500 ``role_init_failed`` (from the observer bootstrap).
    """
    user_id = await _get_user_id(db, username)
    if user_id is None:
        raise HTTPException(status_code=500, detail="user_not_found_after_register")

    role_id = await _get_role_id(db, role_key)
    if role_id is None and role_key == "observer":
        role_id = await _ensure_observer_role(db)
    elif role_id is None:
        raise HTTPException(status_code=400, detail="role_not_exist")

    # Replace (rather than add to) whatever roles the user had before.
    await db.execute(text("DELETE FROM user_role_mapping WHERE user_id=:uid"), {"uid": user_id})
    await db.execute(
        text("INSERT INTO user_role_mapping(user_id, role_id) VALUES(:uid, :rid)"),
        {"uid": user_id, "rid": role_id},
    )
    await db.commit()
|
|
|
|
|
|
async def _get_user_roles(db: AsyncSession, user_id: int) -> list[str]:
    """Return the ``role_key`` of every role mapped to *user_id*.

    Args:
        db: Async database session used for the query.
        user_id: ``users.id`` of the account.

    Returns:
        Role keys in database order; empty if the user has no mappings.
    """
    roles_query = text(
        "SELECT r.role_key FROM roles r JOIN user_role_mapping urm ON r.id = urm.role_id WHERE urm.user_id = :uid"
    )
    result = await db.execute(roles_query, {"uid": user_id})
    return [role_key for (role_key,) in result.all()]
|
|
|
|
|
|
async def _get_role_permissions(db: AsyncSession, role_keys: list[str]) -> list[str]:
    """Return the distinct permission keys granted to any of *role_keys*.

    The query uses ``= ANY(:keys)``, which relies on the driver binding the
    Python list as a SQL array (PostgreSQL syntax).

    Args:
        db: Async database session used for the query.
        role_keys: Role keys to collect permissions for.

    Returns:
        Distinct ``permission_key`` values; an empty list (without touching
        the database) when *role_keys* is empty or ``None``.
    """
    if role_keys:
        permissions_query = text("""
SELECT DISTINCT p.permission_key
FROM permissions p
JOIN role_permission_mapping rpm ON p.id = rpm.permission_id
JOIN roles r ON rpm.role_id = r.id
WHERE r.role_key = ANY(:keys)
""")
        rows = (await db.execute(permissions_query, {"keys": role_keys})).all()
        return [permission_key for (permission_key,) in rows]
    return []
|
|
|
|
|
|
@router.post("/user/login")
async def login(req: LoginRequest, db: AsyncSession = Depends(get_db)):
    """Authenticate a user and return a signed JWT with roles and permissions.

    Two credential paths are tried in order:

    1. A hard-coded demo table (``admin`` / ``ops`` / ``obs``). On a match,
       roles come from ``user_role_mapping`` if a ``users`` row with that
       username exists; otherwise a fixed fallback mapping is used. This path
       does not check ``is_active`` and reports ``fullName`` as the username.
    2. The ``users`` table: the account must exist and be active, and the
       password must match the stored bcrypt hash. ``last_login`` and
       ``updated_at`` are then set to the database's ``NOW()`` and committed.

    Returns:
        ``{"ok", "username", "fullName", "token", "roles", "permissions"}``
        where ``token`` is an HS256 JWT with ``sub`` = username and an
        ``exp`` ``JWT_EXPIRE_MINUTES`` from now.

    Raises:
        HTTPException: 401 ``invalid_credentials``, 403 ``inactive_user``,
            or 500 ``server_error`` for any unexpected failure.
    """
    import hmac
    import logging

    # SECURITY(review): these hard-coded demo credentials bypass the stored
    # password hash and the is_active check, and take precedence over any real
    # DB account with the same username. Gate behind a config flag or remove
    # before production.
    demo = {"admin": "admin123", "ops": "ops123", "obs": "obs123"}

    try:
        demo_password = demo.get(req.username)
        # Constant-time comparison so response timing does not reveal how
        # much of the password matched. Compare bytes: compare_digest rejects
        # non-ASCII str arguments.
        if demo_password is not None and hmac.compare_digest(
            req.password.encode("utf-8"), demo_password.encode("utf-8")
        ):
            username = req.username
            full_name = req.username
            uid = await _get_user_id(db, username)
            roles = await _get_user_roles(db, uid) if uid is not None else []
            if not roles:
                # No mapping recorded in the DB: fall back to a fixed default.
                role_map = {"admin": ["admin"], "ops": ["operator"], "obs": ["observer"]}
                roles = role_map.get(username, [])
        else:
            result = await db.execute(select(User).where(User.username == req.username).limit(1))
            user = result.scalars().first()
            if not user:
                raise HTTPException(status_code=401, detail="invalid_credentials")
            if not user.is_active:
                raise HTTPException(status_code=403, detail="inactive_user")
            if not bcrypt.verify(req.password, user.password_hash):
                raise HTTPException(status_code=401, detail="invalid_credentials")

            # Capture attributes BEFORE committing: with the default
            # expire_on_commit, reading them afterwards would trigger a lazy
            # refresh, which AsyncSession forbids (MissingGreenlet).
            user_id = user.id
            username = user.username
            full_name = user.full_name

            await db.execute(
                update(User).where(User.id == user_id).values(last_login=func.now(), updated_at=func.now())
            )
            await db.commit()
            roles = await _get_user_roles(db, user_id)

        permissions = await _get_role_permissions(db, roles)

        # Use an aware UTC datetime: PyJWT interprets a naive datetime as UTC,
        # so a naive local (e.g. Beijing) time would shift the expiry.
        exp = datetime.now(timezone.utc) + timedelta(minutes=JWT_EXPIRE_MINUTES)
        token = jwt.encode({"sub": username, "exp": exp}, JWT_SECRET, algorithm="HS256")
        return {
            "ok": True,
            "username": username,
            "fullName": full_name,
            "token": token,
            "roles": roles,
            "permissions": permissions,
        }
    except HTTPException:
        raise
    except Exception as exc:
        # Log the real cause server-side; the client only sees a generic code.
        logging.getLogger(__name__).exception("login failed for user %r", req.username)
        raise HTTPException(status_code=500, detail="server_error") from exc
|
|
|
|
|
|
def _validate_register_request(req: RegisterRequest) -> list[dict]:
    """Return a list of field errors for *req* (empty when the request is valid).

    Each error is ``{"field", "code", "message"}`` with a user-facing
    Chinese message.
    """
    errors: list[dict] = []

    # Username: 3-50 chars, starts with a letter; letters/digits/underscore.
    if not req.username or not (3 <= len(req.username) <= 50):
        errors.append({"field": "username", "code": "invalid_username", "message": "用户名长度需在3-50之间"})
    elif not re.fullmatch(r"^[A-Za-z][A-Za-z0-9_]*$", req.username):
        errors.append({"field": "username", "code": "invalid_username", "message": "用户名需以字母开头,仅支持字母、数字和下划线"})

    # Email: loose local@domain.tld shape check.
    if not req.email or not re.fullmatch(r"^[^@\s]+@[^@\s]+\.[^@\s]+", req.email):
        errors.append({"field": "email", "code": "invalid_email", "message": "邮箱格式不正确"})

    # Password: the frontend only requires >= 6 chars, but the backend
    # enforces >= 8 chars with upper, lower and a digit. Both branches reject
    # the request, even though the second message is worded as a suggestion.
    if not req.password or len(req.password) < 6:
        errors.append({"field": "password", "code": "weak_password", "message": "密码长度至少为6位"})
    elif (
        len(req.password) < 8
        or not re.search(r"[A-Z]", req.password)
        or not re.search(r"[a-z]", req.password)
        or not re.search(r"\d", req.password)
    ):
        errors.append({"field": "password", "code": "weak_password", "message": "密码建议至少8位,且包含大小写字母与数字"})

    # Full name: 2-100 chars.
    if not req.fullName or not (2 <= len(req.fullName) <= 100):
        errors.append({"field": "fullName", "code": "invalid_full_name", "message": "姓名长度需在2-100之间"})

    return errors


@router.post("/user/register")
async def register(req: RegisterRequest, db: AsyncSession = Depends(get_db)):
    """Create an active account with the ``observer`` role and log it in.

    Validates the request, rejects duplicate usernames/emails, stores a
    bcrypt hash of the password, maps the user to ``observer`` (creating that
    role if needed), and returns the same payload shape as ``/user/login``.

    Raises:
        HTTPException: 400 with ``{"errors", "message"}`` for validation
            failures; 400 ``user_exists`` / ``email_exists`` for duplicates;
            500 ``server_error`` for any unexpected failure (details are
            logged, not returned to the client).
    """
    import logging

    try:
        errors = _validate_register_request(req)
        if errors:
            raise HTTPException(status_code=400, detail={"errors": errors, "message": errors[0]["message"]})

        # Uniqueness checks. NOTE(review): check-then-insert can still race
        # under concurrent registrations; a DB unique constraint would then
        # surface as an IntegrityError (reported as server_error here).
        exists_username = await db.execute(select(User.id).where(User.username == req.username).limit(1))
        if exists_username.scalars().first() is not None:
            raise HTTPException(status_code=400, detail={"message": "该用户名已被注册", "code": "user_exists"})

        exists_email = await db.execute(select(User.id).where(User.email == req.email).limit(1))
        if exists_email.scalars().first() is not None:
            raise HTTPException(status_code=400, detail={"message": "该邮箱已被绑定", "code": "email_exists"})

        created = now_bj()
        user = User(
            username=req.username,
            email=req.email,
            password_hash=bcrypt.hash(req.password),
            full_name=req.fullName,
            is_active=True,
            last_login=None,
            created_at=created,
            updated_at=created,
        )
        db.add(user)
        await db.commit()  # commit() flushes; a separate flush() is redundant

        # NOTE(review): the user is already committed; if role mapping fails
        # the account exists without a role.
        await _map_user_role(db, req.username, "observer")
        permissions = await _get_role_permissions(db, ["observer"])

        # Read identity from the request, not ``user.*``: after the commits
        # above the ORM attributes may be expired, and a lazy refresh is not
        # allowed under AsyncSession (MissingGreenlet).
        # Aware UTC datetime: PyJWT interprets naive datetimes as UTC.
        exp = datetime.now(timezone.utc) + timedelta(minutes=JWT_EXPIRE_MINUTES)
        token = jwt.encode({"sub": req.username, "exp": exp}, JWT_SECRET, algorithm="HS256")
        return {
            "ok": True,
            "username": req.username,
            "fullName": req.fullName,
            "token": token,
            "roles": ["observer"],
            "permissions": permissions,
        }
    except HTTPException:
        raise
    except Exception as exc:
        # Do not echo exception text to the client: it can expose schema,
        # SQL and driver details. Log it server-side instead.
        logging.getLogger(__name__).exception("register failed for user %r", req.username)
        raise HTTPException(status_code=500, detail="server_error") from exc
|