用户-角色映射优化

pull/48/head
echo 5 months ago
parent 2eb858a743
commit eb935fa5ed

@ -1,7 +1,7 @@
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, update, delete, func
from sqlalchemy import select, update, delete, func, text
from ..db import get_db
from ..models.users import User
from ..deps.auth import get_current_user
@ -33,6 +33,37 @@ def _status_to_active(status: str) -> bool:
def _active_to_status(active: bool) -> str:
return "enabled" if active else "disabled"
async def _get_user_id(db: AsyncSession, username: str) -> int | None:
res = await db.execute(text("SELECT id FROM users WHERE username=:u LIMIT 1"), {"u": username})
row = res.first()
return row[0] if row else None
async def _get_role_id(db: AsyncSession, role_key: str) -> int | None:
res = await db.execute(text("SELECT id FROM roles WHERE role_key=:k LIMIT 1"), {"k": role_key})
row = res.first()
return row[0] if row else None
async def _get_role_key(db: AsyncSession, username: str) -> str | None:
res = await db.execute(
text(
"SELECT r.role_key FROM roles r JOIN user_role_mapping m ON r.id=m.role_id JOIN users u ON u.id=m.user_id WHERE u.username=:u LIMIT 1"
),
{"u": username},
)
row = res.first()
return row[0] if row else None
async def _set_user_role(db: AsyncSession, username: str, role_key: str) -> bool:
uid = await _get_user_id(db, username)
if uid is None:
return False
rid = await _get_role_id(db, role_key)
if rid is None:
return False
await db.execute(text("DELETE FROM user_role_mapping WHERE user_id=:uid"), {"uid": uid})
await db.execute(text("INSERT INTO user_role_mapping(user_id, role_id) VALUES(:uid, :rid)"), {"uid": uid, "rid": rid})
await db.commit()
return True
def _role_or_default(username: str) -> str:
if username in ROLE_OVERRIDES:
@ -62,15 +93,17 @@ async def list_users(user=Depends(get_current_user), db: AsyncSession = Depends(
_require_admin(user)
result = await db.execute(select(User).limit(500))
rows = result.scalars().all()
users = [
{
"username": u.username,
"email": u.email,
"role": _role_or_default(u.username),
"status": _active_to_status(u.is_active),
}
for u in rows
]
users = []
for u in rows:
rk = await _get_role_key(db, u.username)
users.append(
{
"username": u.username,
"email": u.email,
"role": rk or "observer",
"status": _active_to_status(u.is_active),
}
)
return {"users": users}
except HTTPException:
raise
@ -117,7 +150,9 @@ async def create_user(req: CreateUserRequest, user=Depends(get_current_user), db
db.add(user_obj)
await db.flush()
await db.commit()
ROLE_OVERRIDES[req.username] = req.role
ok = await _set_user_role(db, req.username, req.role)
if not ok:
raise HTTPException(status_code=400, detail={"errors": [{"field": "role", "message": "角色不存在"}]})
return {"ok": True}
except HTTPException:
raise
@ -141,7 +176,9 @@ async def update_user(username: str, req: UpdateUserRequest, user=Depends(get_cu
if req.role is not None:
if req.role not in {"admin", "operator", "observer"}:
raise HTTPException(status_code=400, detail={"errors": [{"field": "role", "message": "不允许的角色"}]})
ROLE_OVERRIDES[username] = req.role
ok = await _set_user_role(db, username, req.role)
if not ok:
raise HTTPException(status_code=400, detail={"errors": [{"field": "role", "message": "角色不存在"}]})
if updates:
updates["updated_at"] = func.now()
await db.execute(update(User).where(User.id == u.id).values(**updates))
@ -170,3 +207,28 @@ async def delete_user(username: str, user=Depends(get_current_user), db: AsyncSe
raise
except Exception:
raise HTTPException(status_code=500, detail="server_error")
@router.get("/users/with-roles")
async def list_users_with_roles(user=Depends(get_current_user), db: AsyncSession = Depends(get_db)):
try:
_require_admin(user)
res = await db.execute(
text(
"SELECT u.username,u.email,u.is_active,r.role_key FROM users u LEFT JOIN user_role_mapping m ON u.id=m.user_id LEFT JOIN roles r ON r.id=m.role_id LIMIT 500"
)
)
rows = res.all()
users = [
{
"username": r[0],
"email": r[1],
"role": r[3] or "observer",
"status": _active_to_status(r[2]),
}
for r in rows
]
return {"users": users}
except HTTPException:
raise
except Exception:
raise HTTPException(status_code=500, detail="server_error")

Loading…
Cancel
Save