develop
echo 3 months ago
parent 8715b46e5d
commit ed368d465c

@ -8,22 +8,15 @@ import json
async def run_diagnose_and_repair(db: AsyncSession, operator: str, context: Dict[str, Any], auto: bool = True, max_steps: int = 3, model: Optional[str] = None) -> Dict[str, Any]:
"""
单智能体根据日志上下文诊断并自动修复Function Calling 模式
"""单智能体根据日志上下文诊断并自动修复Function Calling
参数:
- db: 异步数据库会话
- operator: 操作人用户名
- context: 包含集群/节点/日志片段等上下文信息的字典
- auto: 是否允许智能体自动执行修复工具默认 True
- max_steps: 允许工具调用的最大迭代步数防止死循环或成本超支
- model: 使用的 LLM 模型名称
返回:
- Dict 包含根因分析执行动作列表及其结果剩余风险评估
- context包含 cluster/node/logs 等关键信息
- auto是否允许自动执行工具默认允许
- max_steps最多工具调用步数
- model指定的模型名称
返回根因动作列表与结果剩余风险
"""
llm = LLMClient()
# Build the initial conversation, casting the agent as an ops expert
messages: List[Dict[str, Any]] = [
{
"role": "system",
@ -34,36 +27,25 @@ async def run_diagnose_and_repair(db: AsyncSession, operator: str, context: Dict
"content": f"上下文: {context}",
},
]
# Fetch the schema definitions of the callable tools
tools = openai_tools_schema()
actions: List[Dict[str, Any]] = []
root_cause = None
residual_risk = "medium"
# Iterate diagnosis steps until max_steps is reached or the LLM gives a final conclusion
for step in range(max_steps):
# Send the chat request to the LLM (with tool definitions)
resp = await llm.chat(messages, tools=tools, stream=False, model=model)
choice = (resp.get("choices") or [{}])[0]
msg = choice.get("message", {})
tool_calls = msg.get("tool_calls") or []
# No tool calls requested: the LLM has produced its final diagnosis text
if not tool_calls:
root_cause = msg.get("content")
break
# With auto=False, never execute tools even if the LLM suggests them
if not auto:
break
# Handle each tool call requested by the LLM
for tc in tool_calls:
fn = (tc.get("function") or {})
name = fn.get("name")
raw_args = fn.get("arguments") or {}
# Parse the tool-call arguments (may be a JSON string or a dict)
if isinstance(raw_args, str):
try:
args = json.loads(raw_args)
@ -73,9 +55,7 @@ async def run_diagnose_and_repair(db: AsyncSession, operator: str, context: Dict
args = raw_args
else:
args = {}
result: Dict[str, Any]
# Dispatch to the matching service tool function by name
if name == "read_log":
result = await tool_read_log(db, operator, args.get("node"), args.get("path"), int(args.get("lines", 200)), args.get("pattern"), args.get("sshUser"))
elif name == "read_cluster_log":
@ -113,10 +93,6 @@ async def run_diagnose_and_repair(db: AsyncSession, operator: str, context: Dict
result = await tool_stop_cluster(db, operator, args.get("cluster_uuid"))
else:
result = {"error": "unknown_tool"}
# Record the executed action and its result
actions.append({"name": name, "args": args, "result": result})
# Feed the tool result back into the message history for the next round of analysis
messages.append({"role": "tool", "content": str(result), "name": name})
return {"rootCause": root_cause, "actions": actions, "residualRisk": residual_risk}

@ -5,21 +5,18 @@ from typing import Dict, Tuple
from datetime import datetime
from zoneinfo import ZoneInfo
# Load environment variables from the .env file
load_dotenv()
# Timezone configuration
APP_TIMEZONE = os.getenv("APP_TIMEZONE", "Asia/Shanghai")
BJ_TZ = ZoneInfo(APP_TIMEZONE)
def now_bj() -> datetime:
"""返回当前北京时间的 datetime 对象"""
return datetime.now(BJ_TZ)
# Database configuration
_db_url = os.getenv("DATABASE_URL")
if not _db_url:
# If DATABASE_URL is not provided directly, assemble it from the individual parameters
_host = os.getenv("DB_HOST")
_port = os.getenv("DB_PORT")
_name = os.getenv("DB_NAME")
@ -28,24 +25,21 @@ if not _db_url:
if all([_host, _port, _name, _user, _password]):
_db_url = f"postgresql+asyncpg://{_user}:{_password}@{_host}:{_port}/{_name}"
else:
# Default development database URL
_db_url = "postgresql+asyncpg://postgres:password@localhost:5432/hadoop_fault_db"
# Async database connection URL
DATABASE_URL = _db_url
# Sync database connection URL (some scripts or tools may need it)
SYNC_DATABASE_URL = _db_url.replace("postgresql+asyncpg://", "postgresql://")
# JWT (JSON Web Token) configuration
JWT_SECRET = os.getenv("JWT_SECRET", "dev-secret")
JWT_EXPIRE_MINUTES = int(os.getenv("JWT_EXPIRE_MINUTES", "60"))
# SSH configuration
SSH_PORT = int(os.getenv("SSH_PORT", "22"))
SSH_TIMEOUT = int(os.getenv("SSH_TIMEOUT", "10"))
ssh_port = SSH_PORT
ssh_timeout = SSH_TIMEOUT
# Hadoop log directory path
LOG_DIR = os.getenv("HADOOP_LOG_DIR", "/usr/local/hadoop/logs")
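# For reference, a minimal .env this module accepts (values mirror the defaults above; adjust as needed):
#
#   APP_TIMEZONE=Asia/Shanghai
#   DATABASE_URL=postgresql+asyncpg://postgres:password@localhost:5432/hadoop_fault_db
#   JWT_SECRET=dev-secret
#   JWT_EXPIRE_MINUTES=60
#   SSH_PORT=22
#   SSH_TIMEOUT=10
#   HADOOP_LOG_DIR=/usr/local/hadoop/logs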

@ -1,21 +1,15 @@
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker, AsyncSession
from .config import DATABASE_URL, APP_TIMEZONE
# Create the async SQLAlchemy engine
engine = create_async_engine(
DATABASE_URL,
echo=False,  # don't log SQL statements (usually False in production)
pool_pre_ping=True,  # ping pooled connections on checkout to ensure they are alive
connect_args={"server_settings": {"timezone": APP_TIMEZONE}},  # set the DB session timezone
)
# Create the async session factory
SessionLocal = async_sessionmaker(engine, expire_on_commit=False, class_=AsyncSession)
async def get_db() -> AsyncSession:
"""
依赖注入函数获取一个异步数据库会话
使用 yield 确保在请求结束后自动关闭会话
"""
"""获取一个异步数据库会话,用于依赖注入。"""
async with SessionLocal() as session:
yield session
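# Typical use of the dependency in a router (sketch; the route path is illustrative):
#
#   from fastapi import APIRouter, Depends
#   from sqlalchemy import text
#
#   router = APIRouter()
#
#   @router.get("/ping-db")
#   async def ping_db(db: AsyncSession = Depends(get_db)):
#       await db.execute(text("SELECT 1"))
#       return {"ok": True}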

@ -8,40 +8,27 @@ import jwt
from typing import List
async def get_current_user(authorization: str | None = Header(None), db: AsyncSession = Depends(get_db)):
"""
FastAPI 依赖项获取并验证当前登录用户
1. 检查 Authorization Header 是否包含 Bearer Token
2. 解码并验证 JWT Token
3. 从数据库中查找用户信息
4. 聚合用户的权限列表包括数据库配置权限和硬编码的预设权限
"""
if not authorization or not authorization.startswith("Bearer "):
raise HTTPException(status_code=401, detail="not_authenticated")
token = authorization[7:]
try:
# Decode the JWT
payload = jwt.decode(token, JWT_SECRET, algorithms=["HS256"])
username = payload.get("sub")
if not username:
raise HTTPException(status_code=401, detail="invalid_token")
# Look up the user in the database
result = await db.execute(select(User).where(User.username == username).limit(1))
user = result.scalars().first()
if not user:
# If the user is a mock/demo account not present in the DB, build a temporary user dict
user_dict = {"username": username, "id": None, "is_active": True}
else:
if not user.is_active:
raise HTTPException(status_code=403, detail="inactive_user")
user_dict = {"username": user.username, "id": user.id, "is_active": user.is_active}
# Fetch the user's permission list: role-linked permissions, UNIONed with
# hard-coded base permissions and role links for the preset usernames (admin/ops/obs)
perms_res = await db.execute(
text("""
SELECT DISTINCT p.permission_key
@ -51,7 +38,7 @@ async def get_current_user(authorization: str | None = Header(None), db: AsyncSe
JOIN users u ON urm.user_id = u.id
WHERE u.username = :u
UNION
-- Compatibility: hard-coded base permissions for the preset roles
SELECT 'cluster:register' AS permission_key
WHERE (:u = 'admin' OR :u = 'ops' OR :u = 'obs')
UNION
@ -64,7 +51,7 @@ async def get_current_user(authorization: str | None = Header(None), db: AsyncSe
SELECT 'cluster:stop' AS permission_key
WHERE (:u = 'admin' OR :u = 'ops')
UNION
-- Compatibility: additional role-linked permissions for demo accounts not present in the DB
SELECT DISTINCT p.permission_key
FROM permissions p
JOIN role_permission_mapping rpm ON p.id = rpm.permission_id
@ -87,16 +74,10 @@ async def get_current_user(authorization: str | None = Header(None), db: AsyncSe
raise HTTPException(status_code=500, detail="auth_error")
class PermissionChecker:
"""
权限校验器类可用作 FastAPI 路由依赖项
用法Depends(PermissionChecker(["cluster:start"]))
"""
def __init__(self, required_permissions: List[str]):
self.required_permissions = required_permissions
def __call__(self, user=Depends(get_current_user)):
"""校验当前用户是否拥有所有必需的权限。"""
user_perms = user.get("permissions", [])
for perm in self.required_permissions:
if perm not in user_perms:

@ -13,13 +13,9 @@ import asyncio
from .config import BJ_TZ, DATABASE_URL, APP_TIMEZONE
class LogCollector:
"""
Hadoop 集群实时日志采集器
通过 SSH 增量读取远程节点的日志文件并解析后存入数据库
"""
"""Real-time log collector for Hadoop cluster"""
def __init__(self):
# Running collector threads, keyed by collector_id
self.collectors: Dict[str, threading.Thread] = {}
self.is_running: bool = False
self.collection_interval: int = 5  # default collection interval, in seconds
@ -30,21 +26,21 @@ class LogCollector:
self._cluster_name_cache: Dict[str, str] = {}
self._targets: Dict[str, str] = {}
self._line_counts: Dict[str, int] = {}
# Max bytes per remote pull, to bound memory use
self.max_bytes_per_pull: int = 256 * 1024
def start_collection(self, node_name: str, log_type: str, ip: Optional[str] = None, interval: Optional[int] = None) -> bool:
"""启动特定节点和日志类型的实时采集任务"""
"""Start real-time log collection for a specific node and log type"""
collector_id = f"{node_name}_{log_type}"
if interval is not None:
self._intervals[collector_id] = max(1, int(interval))
# Skip if this collection task is already running
if collector_id in self.collectors and self.collectors[collector_id].is_alive():
print(f"采集器 {collector_id} 已经在运行中")
print(f"Collector {collector_id} is already running")
return False
# Start even if the log file does not exist yet; the collector self-checks in its loop
# Create a new collector thread (daemon)
collector_thread = threading.Thread(
target=self._collect_logs,
args=(node_name, log_type, ip),
@ -54,44 +50,48 @@ class LogCollector:
self.collectors[collector_id] = collector_thread
collector_thread.start()
print(f"已启动采集器 {collector_id}")
print(f"Started collector {collector_id}")
return True
def stop_collection(self, node_name: str, log_type: str):
"""停止特定节点和日志类型的采集任务"""
"""Stop log collection for a specific node and log type"""
collector_id = f"{node_name}_{log_type}"
if collector_id in self.collectors:
# Threads are daemon; removing the entry from the tracking dict is enough,
# since the collection loop checks this dict and exits on its own
del self.collectors[collector_id]
self._intervals.pop(collector_id, None)
print(f"已停止采集器 {collector_id}")
print(f"Stopped collector {collector_id}")
else:
print(f"采集器 {collector_id} 未在运行")
print(f"Collector {collector_id} is not running")
def stop_all_collections(self):
"""停止所有正在运行的日志采集任务"""
"""Stop all log collections"""
for collector_id in list(self.collectors.keys()):
self.stop_collection(*collector_id.split("_"))
def _parse_log_line(self, line: str, node_name: str, log_type: str):
"""解析单行日志,提取时间戳、日志级别、消息内容等字段"""
# 默认值
"""Parse a single log line and return a dictionary of log fields"""
# Extract timestamp from the log line (format: [2023-12-17 10:00:00,123])
timestamp = None
log_level = "INFO"
log_level = "INFO" # Default log level
message = line
exception = None
# Try to parse the standard Hadoop log format: [YYYY-MM-DD HH:MM:SS,mmm]
if line.startswith('['):
# Extract timestamp
timestamp_end = line.find(']', 1)
if timestamp_end > 0:
timestamp_str = line[1:timestamp_end]
try:
timestamp = datetime.datetime.strptime(timestamp_str, "%Y-%m-%d %H:%M:%S,%f").replace(tzinfo=BJ_TZ)
except ValueError:
# If parsing fails, use current time
timestamp = datetime.datetime.now(BJ_TZ)
# Extract the log level keyword
log_levels = ["ERROR", "WARN", "INFO", "DEBUG", "TRACE"]
for level in log_levels:
if f" {level} " in line:
@ -108,12 +108,12 @@ class LogCollector:
}
async def _save_log_to_db(self, log_data: Dict, collector_id: str | None = None):
"""将单条日志数据异步保存到数据库"""
"""Save log data to database"""
try:
session_local = self._session_locals.get(collector_id) if collector_id else None
async with (session_local() if session_local else SessionLocal()) as session:
host = log_data["host"]
# Resolve the cluster name, preferring the cached value
cluster_name = self._cluster_name_cache.get(host)
if not cluster_name:
cluster_res = await session.execute(text("""
@ -126,7 +126,7 @@ class LogCollector:
cluster_name = cluster_row[0] if cluster_row else "default_cluster"
self._cluster_name_cache[host] = cluster_name
# Create the HadoopLog model instance
hadoop_log = HadoopLog(
log_time=log_data["timestamp"],
node_host=log_data["host"],
@ -135,13 +135,14 @@ class LogCollector:
cluster_name=cluster_name
)
# Add to session and commit
session.add(hadoop_log)
await session.commit()
except Exception as e:
print(f"保存日志到数据库时出错: {e}")
print(f"Error saving log to database: {e}")
async def _save_logs_to_db_batch(self, logs: List[Dict], collector_id: str | None = None):
"""批量将日志数据异步保存到数据库,提高写入性能"""
"""Save a batch of logs to database in one transaction"""
try:
session_local = self._session_locals.get(collector_id) if collector_id else None
async with (session_local() if session_local else SessionLocal()) as session:
@ -170,7 +171,7 @@ class LogCollector:
session.add_all(objs)
await session.commit()
except Exception as e:
print(f"批量保存日志时出错: {e}")
print(f"Error batch saving logs: {e}")
def _collect_logs(self, node_name: str, log_type: str, ip: str):
"""Internal method to collect logs continuously"""

@ -3,14 +3,11 @@ from .config import LOG_DIR
from .ssh_utils import ssh_manager
class LogReader:
"""
Hadoop 日志读取类负责从远程集群节点读取并处理 Hadoop 相关日志
"""
"""Log Reader for Hadoop cluster nodes"""
def __init__(self):
self.log_dir = LOG_DIR
self._node_log_dir: Dict[str, str] = {}
# Candidate directories where Hadoop logs may live
self._candidates = [
"/usr/local/hadoop/logs",
"/opt/hadoop/logs",
@ -22,8 +19,8 @@ class LogReader:
]
def get_log_file_path(self, node_name: str, log_type: str) -> str:
"""根据节点名和日志类型生成默认的日志文件路径"""
# 将日志类型映射到实际的文件名前缀
"""Generate log file path based on node name and log type"""
# Map log type to actual log file name
log_file_map = {
"namenode": "hadoop-hadoop-namenode",
"datanode": "hadoop-hadoop-datanode",
@ -32,30 +29,29 @@ class LogReader:
"historyserver": "hadoop-hadoop-historyserver"
}
# Get the base log file name
base_name = log_file_map.get(log_type.lower(), log_type.lower())
# Generate the full log file path
return f"{self.log_dir}/{base_name}-{node_name.replace('_', '')}.log"
def read_log(self, node_name: str, log_type: str, ip: str) -> str:
"""从特定节点读取日志内容"""
# 探测该节点上实际可用的日志目录
"""Read log from a specific node"""
# Ensure working log dir
self.find_working_log_dir(node_name, ip)
# Candidate log file paths
paths = self.get_log_file_paths(node_name, log_type)
# Get an SSH connection
ssh_client = ssh_manager.get_connection(node_name, ip=ip)
# Try the direct candidate paths first
for p in paths:
out, err = ssh_client.execute_command(f"ls -la {p} 2>/dev/null")
if not err and out.strip():
out, err = ssh_client.execute_command(f"cat {p} 2>/dev/null")
if not err:
return out
# If the direct paths fail, list the directory and match by name
base_dir = self._node_log_dir.get(node_name, self.log_dir)
out, err = ssh_client.execute_command(f"ls -la {base_dir} 2>/dev/null")
if not err and out.strip():
@ -64,50 +60,50 @@ class LogReader:
if parts:
fn = parts[-1]
lf = fn.lower()
# Match files that contain the log type and node name and end in .log or .out
if log_type in lf and node_name in lf and (lf.endswith(".log") or lf.endswith(".out") or lf.endswith(".out.1")):
out2, err2 = ssh_client.execute_command(f"cat {base_dir}/{fn} 2>/dev/null")
if not err2:
return out2
raise FileNotFoundError("未找到匹配的日志文件")
raise FileNotFoundError("No such file")
def read_all_nodes_log(self, nodes: List[Dict[str, str]], log_type: str) -> Dict[str, str]:
"""批量读取所有节点的指定类型日志"""
"""Read log from all nodes"""
logs = {}
for node in nodes:
node_name = node['name']
ip = node.get('ip')
if not ip:
logs[node_name] = "错误:未找到 IP 地址"
logs[node_name] = "Error: IP address not found"
continue
try:
logs[node_name] = self.read_log(node_name, log_type, ip)
except Exception as e:
logs[node_name] = f"读取日志失败: {str(e)}"
logs[node_name] = f"Error reading log: {str(e)}"
return logs
def filter_log_by_date(self, log_content: str, start_date: str, end_date: str) -> str:
"""根据日期范围过滤日志内容"""
"""Filter log content by date range"""
filtered_lines = []
for line in log_content.splitlines():
# Check whether the line starts with a date in the format [YYYY-MM-DD HH:MM:SS,mmm]
if line.startswith('['):
date_str = line[1:11]  # extract the YYYY-MM-DD part
if start_date <= date_str <= end_date:
filtered_lines.append(line)
return '\n'.join(filtered_lines)
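# Usage sketch: LogReader().filter_log_by_date(content, "2023-12-01", "2023-12-17") keeps only lines
# stamped within that inclusive range; the plain string comparison is valid because the
# YYYY-MM-DD format sorts lexicographically in date order.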
def get_log_files_list(self, node_name: str, ip: Optional[str] = None) -> List[str]:
"""获取指定节点上所有可用的日志文件列表"""
"""Get list of log files on a specific node"""
# Ensure working log dir
if ip:
self.find_working_log_dir(node_name, ip)
ssh_client = ssh_manager.get_connection(node_name, ip=ip)
# List log files from the available candidate directories
dirs = [self._node_log_dir.get(node_name, self.log_dir)] + self._candidates
stdout = ""
for d in dirs:
@ -116,25 +112,30 @@ class LogReader:
stdout = out
self._node_log_dir[node_name] = d
break
stderr = ""
# Parse log files from output
log_files = []
if not stderr and stdout.strip():
for line in stdout.splitlines():
name = line.strip()
# Only collect log or output files
if name.endswith(".log") or name.endswith(".out") or name.endswith(".out.1"):
log_files.append(name)
return log_files
def check_log_file_exists(self, node_name: str, log_type: str, ip: Optional[str] = None) -> bool:
"""检查特定类型的日志文件是否存在于节点上"""
"""Check if log file exists on a specific node"""
# Ensure working log dir
if ip:
self.find_working_log_dir(node_name, ip)
paths = self.get_log_file_paths(node_name, log_type)
# Get SSH connection
ssh_client = ssh_manager.get_connection(node_name, ip=ip)
try:
# Execute command to check if file exists
for p in paths:
stdout, stderr = ssh_client.execute_command(f"ls -la {p} 2>/dev/null")
if not stderr and stdout.strip():
@ -150,13 +151,15 @@ class LogReader:
return True
return False
except Exception as e:
print(f"检查日志文件是否存在时出错: {e}")
print(f"Error checking log file existence: {e}")
return False
def get_node_services(self, node_name: str) -> List[str]:
"""通过分析日志文件名来推断节点上运行的服务"""
"""Get list of running services on a node based on log files"""
# Get all log files
log_files = self.get_log_files_list(node_name)
# Extract service types from log file names
services = []
for log_file in log_files:
if "namenode" in log_file:
@ -170,32 +173,27 @@ class LogReader:
elif "secondarynamenode" in log_file:
services.append("secondarynamenode")
# Remove duplicates
return list(set(services))
def find_working_log_dir(self, node_name: str, ip: str) -> str:
"""在远程节点上探测并设置实际可用的日志目录"""
"""Detect a working log directory on remote node and set it"""
ssh_client = ssh_manager.get_connection(node_name, ip=ip)
# Try the currently known directory first
current = self._node_log_dir.get(node_name, self.log_dir)
stdout, stderr = ssh_client.execute_command(f"ls -la {current}")
if not stderr and stdout.strip():
self._node_log_dir[node_name] = current
return current
# Fall back to the candidate directories in order
for d in [current] + self._candidates:
stdout, stderr = ssh_client.execute_command(f"ls -la {d} 2>/dev/null")
if not stderr and stdout.strip():
self._node_log_dir[node_name] = d
return d
# Default fallback
self._node_log_dir[node_name] = self.log_dir
return self._node_log_dir[node_name]
def get_log_file_paths(self, node_name: str, log_type: str) -> List[str]:
"""获取可能的日志文件全路径列表"""
base_dir = self._node_log_dir.get(node_name, self.log_dir)
base = f"{base_dir}/hadoop-hadoop-{log_type}-{node_name}"
return [f"{base}.log", f"{base}.out", f"{base}.out.1"]

@ -2,21 +2,18 @@ from fastapi import FastAPI, Request, status
from fastapi.responses import JSONResponse
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
# Import the routers of each feature module
from .routers import auth, health, secure, users, clusters, nodes, metrics, faults, ops, ai, hadoop_logs, sys_exec_logs, hadoop_exec_logs
import os
# Initialize the FastAPI app with a title and version
app = FastAPI(title="Hadoop Fault Detecting API", version="v1")
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
"""
异常处理器Pydantic 校验错误转换为前端更易解析的统一格式
Pydantic 校验错误转换为前端更易解析的格式
"""
errors = []
for error in exc.errors():
# Name of the field that failed validation
field = error.get("loc")[-1] if error.get("loc") else "unknown"
msg = error.get("msg")
errors.append({
@ -30,27 +27,25 @@ async def validation_exception_handler(request: Request, exc: RequestValidationE
content={"detail": {"errors": errors, "message": "请求参数校验失败"}}
)
# Configure the Cross-Origin Resource Sharing (CORS) middleware
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # 允许所有来源(生产环境建议限制特定域名)
allow_origins=["*"],
allow_credentials=False,
allow_methods=["*"], # 允许所有 HTTP 方法
allow_headers=["*"], # 允许所有请求头
allow_methods=["*"],
allow_headers=["*"],
)
# Register each feature module's router under the unified /api/v1 prefix
app.include_router(health.router, prefix="/api/v1")            # health checks
app.include_router(auth.router, prefix="/api/v1")              # authentication
app.include_router(secure.router, prefix="/api/v1")            # security
app.include_router(clusters.router, prefix="/api/v1")          # cluster management
app.include_router(nodes.router, prefix="/api/v1")             # node management
app.include_router(metrics.router, prefix="/api/v1")           # metrics monitoring
app.include_router(users.router, prefix="/api/v1")             # user management
app.include_router(hadoop_logs.router, prefix="/api/v1")       # Hadoop log viewing
app.include_router(faults.router, prefix="/api/v1")            # fault records
app.include_router(hadoop_exec_logs.router, prefix="/api/v1")  # Hadoop exec logs
app.include_router(ops.router, prefix="/api/v1")               # ops actions
app.include_router(ai.router, prefix="/api/v1")                # AI diagnosis
app.include_router(sys_exec_logs.router, prefix="/api/v1")     # system exec logs
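# Local run sketch (assumes this module is importable as app.main; adjust to the real package path):
#   uvicorn app.main:app --reload --port 8000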

@ -12,75 +12,43 @@ import asyncio
from .config import BJ_TZ
class MetricsCollector:
"""
节点指标采集器
负责通过 SSH 定期从远程节点采集 CPU 和内存使用率并更新到数据库
支持多线程并行采集并提供错误追踪和采集间隔控制
"""
def __init__(self):
# Running collector threads {hostname: thread}, one collection task per node
self.collectors: Dict[str, threading.Thread] = {}
# Default collection interval in seconds
self.collection_interval: int = 5
# Most recent collection error per node
self.last_errors: Dict[str, str] = {}
# Cache of table column names, for dynamically checking whether a column exists
self._columns_cache: Dict[str, set] = {}
# Whether cluster-average metrics have been initialized (not used much by the current logic)
self._cluster_avg_inited: bool = False
def set_collection_interval(self, interval: int):
"""
设置采集间隔时间
:param interval: 间隔秒数最小为 1
"""
self.collection_interval = max(1, interval)
def get_collectors_status(self) -> Dict[str, bool]:
"""
获取所有采集任务的运行状态
:return: 字典 {节点名: 是否正在运行}
"""
status = {}
for cid, t in self.collectors.items():
status[cid] = t.is_alive()
return status
def get_errors(self) -> Dict[str, str]:
"""获取最近发生的采集错误信息字典"""
return dict(self.last_errors)
def stop_all(self):
"""停止系统中所有正在运行的指标采集任务"""
for cid in list(self.collectors.keys()):
self.stop(cid)
def stop(self, collector_id: str):
"""
停止特定节点的采集任务并清理相关状态
:param collector_id: 通常为节点的主机名
"""
if collector_id in self.collectors:
del self.collectors[collector_id]
if collector_id in self.last_errors:
del self.last_errors[collector_id]
def start_for_nodes(self, nodes: List[Tuple[int, str, str, int]], interval: Optional[int] = None) -> Tuple[int, List[str]]:
"""
为指定的一组节点批量启动指标采集任务
:param nodes: 节点信息列表 [(node_id, hostname, ip, cluster_id), ...]
:param interval: 可选设置新的全局采集间隔
:return: (成功启动的数量, 启动成功的节点名列表)
"""
if interval:
self.set_collection_interval(interval)
started: List[str] = []
for nid, hn, ip, cid in nodes:
cid_str = f"{hn}"
# Skip if this node's collection task is already running
if cid_str in self.collectors and self.collectors[cid_str].is_alive():
continue
# Each node gets its own daemon thread for periodic collection
t = threading.Thread(target=self._collect_node_metrics, args=(nid, hn, ip, cid), name=f"metrics_{hn}", daemon=True)
self.collectors[cid_str] = t
t.start()
@ -88,18 +56,9 @@ class MetricsCollector:
return len(started), started
def _read_cpu_mem(self, node_name: str, ip: str) -> Tuple[float, float]:
"""
通过 SSH 在远程节点上执行命令并解析 CPU 和内存使用率
CPU 计算逻辑通过读取 /proc/stat 两次采样间隔 0.5s计算非空闲时间占比
内存计算逻辑通过读取 /proc/meminfo 计算 (Total - Available) / Total
"""
ssh_client = ssh_manager.get_connection(node_name, ip=ip)
# --- CPU usage ---
# First sample
out1, err1 = ssh_client.execute_command("cat /proc/stat | head -n 1")
_time.sleep(0.5)
# Second sample
out2, err2 = ssh_client.execute_command("cat /proc/stat | head -n 1")
cpu_pct = 0.0
if not err1 and not err2 and out1.strip() and out2.strip():
@ -107,16 +66,12 @@ class MetricsCollector:
p2 = out2.strip().split()
v1 = [int(x) for x in p1[1:]]
v2 = [int(x) for x in p2[1:]]
# /proc/stat fields: user, nice, system, idle, iowait, irq, softirq, steal
get1 = lambda i: (v1[i] if i < len(v1) else 0)
get2 = lambda i: (v2[i] if i < len(v2) else 0)
# Idle and total time deltas between the two samples
idle = (get2(3) + get2(4)) - (get1(3) + get1(4))
total = (get2(0) - get1(0)) + (get2(1) - get1(1)) + (get2(2) - get1(2)) + idle + (get2(5) - get1(5)) + (get2(6) - get1(6)) + (get2(7) - get1(7))
if total > 0:
cpu_pct = round((1.0 - idle / total) * 100.0, 2)
# --- Memory usage ---
outm, errm = ssh_client.execute_command("cat /proc/meminfo")
mem_pct = 0.0
if not errm and outm.strip():
@ -128,48 +83,29 @@ class MetricsCollector:
elif line.startswith("MemAvailable:"):
ma = int(line.split()[1])
if mt > 0:
# Memory usage = (total - available) / total
mem_pct = round((1.0 - (ma / mt)) * 100.0, 2)
return cpu_pct, mem_pct
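# Worked example of the /proc/stat delta math above (illustrative numbers, irq/softirq/steal = 0):
#   sample 1: user=100 nice=0 system=50 idle=800 iowait=50  -> idle+iowait = 850
#   sample 2: user=140 nice=0 system=70 idle=960 iowait=60  -> idle+iowait = 1020
#   idle delta  = 1020 - 850 = 170
#   total delta = 40 + 0 + 20 + 170 = 230
#   cpu_pct     = round((1 - 170/230) * 100, 2) = 26.09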
async def _save_metrics(self, node_id: int, hostname: str, cluster_id: int, cpu: float, mem: float):
"""
将采集到的实时指标数据持久化到数据库
更新 nodes 表的实时指标字段和最后心跳时间
"""
# 这里的 SessionLocal 绑定的 engine 可能在主线程 loop 中初始化
# 在 asyncio.run() 开启的新 loop 中使用它会报 Loop 冲突
from .db import engine
async with AsyncSession(engine) as session:
now = datetime.datetime.now(BJ_TZ)
# Update the monitoring metrics in the nodes table via raw SQL
await session.execute(text("UPDATE nodes SET cpu_usage=:cpu, memory_usage=:mem, last_heartbeat=:hb WHERE id=:nid"), {"cpu": cpu, "mem": mem, "hb": now, "nid": node_id})
await session.commit()
def _collect_node_metrics(self, node_id: int, hostname: str, ip: str, cluster_id: int):
"""
采集线程的主循环函数
在线程存活期间周期性调用 SSH 采集数据并保存
"""
cid = hostname
while cid in self.collectors:
try:
# 1. Read the remote metrics
cpu, mem = self._read_cpu_mem(hostname, ip)
# 2. We are on a sync thread, so drive the async DB save with asyncio.run
asyncio.run(self._save_metrics(node_id, hostname, cluster_id, cpu, mem))
# Clear any previous error on success
if cid in self.last_errors:
del self.last_errors[cid]
except Exception as e:
# Record the collection error
self.last_errors[cid] = str(e)
# 3. Wait for the next collection cycle
time.sleep(self.collection_interval)
async def _get_table_columns(self, session: AsyncSession, table_name: str) -> set:
"""
获取指定数据库表的列名集合
使用了缓存机制避免重复查询 information_schema
"""
if table_name in self._columns_cache:
return self._columns_cache[table_name]
res = await session.execute(text("""
@ -181,5 +117,5 @@ class MetricsCollector:
self._columns_cache[table_name] = cols
return cols
# Global metrics collector instance, shared across the app
metrics_collector = MetricsCollector()
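# Usage sketch (hypothetical node tuple: node_id, hostname, ip, cluster_id):
#   started, names = metrics_collector.start_for_nodes([(1, "nn1", "192.168.1.10", 1)], interval=10)
#   metrics_collector.get_collectors_status()  # e.g. {"nn1": True}
#   metrics_collector.get_errors()             # last error per node; empty when healthy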

@ -5,32 +5,26 @@ from ..config import BJ_TZ
from . import Base
class ChatSession(Base):
"""
聊天会话模型
记录用户与 AI 助手之间的对话会话
"""
__tablename__ = "chat_sessions"
id = Column(String, primary_key=True, index=True)  # session UUID
user_id = Column(Integer, nullable=True, index=True)  # optional link to a user
title = Column(String, nullable=True)
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(BJ_TZ))
updated_at = Column(DateTime(timezone=True), default=lambda: datetime.now(BJ_TZ), onupdate=lambda: datetime.now(BJ_TZ))
# One-to-many: a session contains many messages
messages = relationship("ChatMessage", back_populates="session", cascade="all, delete-orphan", lazy="selectin")
class ChatMessage(Base):
"""
聊天消息模型
存储会话中的每一条具体消息内容
"""
__tablename__ = "chat_messages"
id = Column(Integer, primary_key=True, index=True)
session_id = Column(String, ForeignKey("chat_sessions.id"), nullable=False) # 所属会话 ID
role = Column(String, nullable=False) # 角色system, user, assistant, tool
content = Column(Text, nullable=False) # 消息内容
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(BJ_TZ)) # 发送时间
session_id = Column(String, ForeignKey("chat_sessions.id"), nullable=False)
role = Column(String, nullable=False) # system, user, assistant, tool
content = Column(Text, nullable=False)
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(BJ_TZ))
# Optional: store tool calls or extra metadata if needed
# For now, we store JSON in content if it's complex, or just text.
session = relationship("ChatSession", back_populates="messages")

@ -3,15 +3,11 @@ from sqlalchemy import String, Integer, Float, TIMESTAMP
from . import Base
class ClusterMetric(Base):
"""
集群指标模型
记录集群整体的资源使用情况历史用于绘制监控图表
"""
__tablename__ = "cluster_metrics"
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
cluster_id: Mapped[int] = mapped_column()
cluster_name: Mapped[str] = mapped_column(String(100))
cpu_avg: Mapped[float] = mapped_column(Float)  # average CPU usage (%)
memory_avg: Mapped[float] = mapped_column(Float)  # average memory usage (%)
created_at: Mapped[str] = mapped_column(TIMESTAMP(timezone=True))

@ -4,28 +4,24 @@ from sqlalchemy.dialects.postgresql import UUID, JSONB, INET
from . import Base
class Cluster(Base):
"""
Hadoop 集群模型
记录集群的基本信息健康状态平均指标以及核心组件NameNode, RM的连接信息
"""
__tablename__ = "clusters"
id: Mapped[int] = mapped_column(primary_key=True)
uuid: Mapped[str] = mapped_column(UUID(as_uuid=False), unique=True)
name: Mapped[str] = mapped_column(String(100), unique=True)
type: Mapped[str] = mapped_column(String(50))  # e.g. hadoop, spark
node_count: Mapped[int] = mapped_column(Integer, default=0)
health_status: Mapped[str] = mapped_column(String(20), default="unknown")  # healthy / warning / error
cpu_avg: Mapped[float | None] = mapped_column(Float, nullable=True)
memory_avg: Mapped[float | None] = mapped_column(Float, nullable=True)
namenode_ip: Mapped[str | None] = mapped_column(INET, nullable=True)
namenode_psw: Mapped[str | None] = mapped_column(String(255), nullable=True)  # NameNode SSH password
rm_ip: Mapped[str | None] = mapped_column(INET, nullable=True)
rm_psw: Mapped[str | None] = mapped_column(String(255), nullable=True)  # ResourceManager SSH password
description: Mapped[str | None] = mapped_column(String, nullable=True)
config_info: Mapped[dict | None] = mapped_column(JSONB, nullable=True)
created_at: Mapped[str] = mapped_column(TIMESTAMP(timezone=True))
updated_at: Mapped[str] = mapped_column(TIMESTAMP(timezone=True))
def to_dict(self) -> dict:
"""将集群对象转换为可序列化字典。"""

@ -5,29 +5,25 @@ from sqlalchemy import TIMESTAMP
from . import Base
class FaultRecord(Base):
"""
故障记录模型
记录系统中检测到的各类故障信息包括故障级别受影响节点原因分析及处理状态
"""
__tablename__ = "fault_records"
id: Mapped[int] = mapped_column(primary_key=True)
fault_id: Mapped[str] = mapped_column(String(32), unique=True)
cluster_id: Mapped[int | None] = mapped_column(nullable=True)
fault_type: Mapped[str] = mapped_column(String(50))  # e.g. network, disk, service
fault_level: Mapped[str] = mapped_column(String(20), default="medium")  # low / medium / high / critical
title: Mapped[str] = mapped_column(String(200))
description: Mapped[str | None] = mapped_column(String, nullable=True)
affected_nodes: Mapped[dict | None] = mapped_column(JSONB, nullable=True)
affected_clusters: Mapped[dict | None] = mapped_column(JSONB, nullable=True)
root_cause: Mapped[str | None] = mapped_column(String, nullable=True)
repair_suggestion: Mapped[str | None] = mapped_column(String, nullable=True)
status: Mapped[str] = mapped_column(String(20), default="detected")  # detected / processing / resolved / closed
assignee: Mapped[str | None] = mapped_column(String(50), nullable=True)
reporter: Mapped[str] = mapped_column(String(50), default="system")  # "system" for auto-detected faults
created_at: Mapped[str] = mapped_column(TIMESTAMP(timezone=True))
updated_at: Mapped[str] = mapped_column(TIMESTAMP(timezone=True))
resolved_at: Mapped[str | None] = mapped_column(TIMESTAMP(timezone=True), nullable=True)
def to_dict(self) -> dict:
"""将故障记录转换为可序列化字典。"""

@ -3,21 +3,16 @@ from sqlalchemy import String, Integer, Text, TIMESTAMP, ForeignKey
from . import Base
class HadoopExecLog(Base):
"""
Hadoop 执行日志模型
记录用户对 Hadoop 集群执行的操作任务如启动停止重启服务等的审计日志
"""
__tablename__ = "hadoop_exec_logs"
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
from_user_id: Mapped[int] = mapped_column(Integer, nullable=False)
cluster_name: Mapped[str] = mapped_column(String(255), nullable=False)
description: Mapped[str | None] = mapped_column(Text, nullable=True)
start_time: Mapped[str | None] = mapped_column(TIMESTAMP(timezone=True), nullable=True)
end_time: Mapped[str | None] = mapped_column(TIMESTAMP(timezone=True), nullable=True)
def to_dict(self) -> dict:
"""将执行日志转换为字典格式"""
return {
"id": self.id,
"from_user_id": self.from_user_id,

@ -3,21 +3,16 @@ from sqlalchemy import String, Integer, Text, TIMESTAMP
from . import Base
class HadoopLog(Base):
"""
Hadoop 日志模型
存储从各个集群节点采集到的原始日志信息包括日志时间来源节点标题服务名和具体内容
"""
__tablename__ = "hadoop_logs"
log_id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
cluster_name: Mapped[str] = mapped_column(String(255), nullable=False)
node_host: Mapped[str] = mapped_column(String(100), nullable=False)
title: Mapped[str | None] = mapped_column(String(255), nullable=True)  # usually the Hadoop service name, e.g. NameNode
info: Mapped[str | None] = mapped_column(Text, nullable=True)
log_time: Mapped[str] = mapped_column(TIMESTAMP(timezone=True), nullable=False)
def to_dict(self) -> dict:
"""将日志对象转换为字典格式"""
return {
"log_id": self.log_id,
"cluster_name": self.cluster_name,

@ -3,16 +3,12 @@ from sqlalchemy import String, Integer, Float, TIMESTAMP
from . import Base
class NodeMetric(Base):
"""
节点指标模型
记录单个服务器节点的资源使用历史数据CPU内存等
"""
__tablename__ = "node_metrics"
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
cluster_id: Mapped[int] = mapped_column()
node_id: Mapped[int] = mapped_column()
hostname: Mapped[str] = mapped_column(String(100))
cpu_usage: Mapped[float] = mapped_column(Float)  # CPU usage (%)
memory_usage: Mapped[float] = mapped_column(Float)  # memory usage (%)
created_at: Mapped[str] = mapped_column(TIMESTAMP(timezone=True))

@ -5,23 +5,20 @@ from sqlalchemy import TIMESTAMP, Float
from . import Base
class Node(Base):
"""
集群节点模型
记录集群中各个服务器节点的基本信息SSH 连接凭据以及实时采集的资源使用指标
"""
__tablename__ = "nodes"
id: Mapped[int] = mapped_column(primary_key=True)
uuid: Mapped[str] = mapped_column(UUID(as_uuid=False), unique=True)
cluster_id: Mapped[int] = mapped_column()
hostname: Mapped[str] = mapped_column(String(100))
ip_address: Mapped[str] = mapped_column(INET)
ssh_user: Mapped[str | None] = mapped_column(String(50), nullable=True)
ssh_password: Mapped[str | None] = mapped_column(String(255), nullable=True)
# description: Mapped[str | None] = mapped_column(String, nullable=True)
status: Mapped[str] = mapped_column(String(20), default="unknown")  # online / offline / unknown
cpu_usage: Mapped[float | None] = mapped_column(Float, nullable=True)  # live CPU usage (%)
memory_usage: Mapped[float | None] = mapped_column(Float, nullable=True)  # live memory usage (%)
disk_usage: Mapped[float | None] = mapped_column(Float, nullable=True)  # live disk usage (%)
last_heartbeat: Mapped[str | None] = mapped_column(TIMESTAMP(timezone=True), nullable=True)
created_at: Mapped[str] = mapped_column(TIMESTAMP(timezone=True))
updated_at: Mapped[str] = mapped_column(TIMESTAMP(timezone=True))

@ -4,19 +4,14 @@ from sqlalchemy.dialects.postgresql import UUID
from . import Base
class SysExecLog(Base):
"""
系统执行日志模型
记录用户在管理平台上进行的系统级操作审计如用户管理配置更改等
"""
__tablename__ = "sys_exec_logs"
operation_id: Mapped[str] = mapped_column(UUID(as_uuid=True), primary_key=True, server_default=text("uuid_generate_v4()"))
user_id: Mapped[int] = mapped_column(Integer, ForeignKey("users.id"), nullable=False)
description: Mapped[str] = mapped_column(Text, nullable=False)
operation_time: Mapped[str] = mapped_column(TIMESTAMP(timezone=True), nullable=False, server_default=text("now()"))
def to_dict(self) -> dict:
"""将操作日志转换为字典格式"""
return {
"operation_id": str(self.operation_id),
"user_id": self.user_id,

@ -4,18 +4,15 @@ from sqlalchemy import TIMESTAMP
from . import Base
class User(Base):
"""
系统用户模型
存储管理人员的账号信息用于系统登录和权限管理
"""
__tablename__ = "users"
id: Mapped[int] = mapped_column(primary_key=True)
username: Mapped[str] = mapped_column(String(50), unique=True)
email: Mapped[str] = mapped_column(String(100), unique=True)
password_hash: Mapped[str] = mapped_column(String(255))
full_name: Mapped[str] = mapped_column(String(100))
is_active: Mapped[bool] = mapped_column(Boolean, default=True)
sort: Mapped[int] = mapped_column(default=0)
last_login: Mapped[str | None] = mapped_column(TIMESTAMP(timezone=True), nullable=True)
created_at: Mapped[str] = mapped_column(TIMESTAMP(timezone=True))
updated_at: Mapped[str] = mapped_column(TIMESTAMP(timezone=True))

@ -20,68 +20,48 @@ router = APIRouter()
class DiagnoseRepairReq(BaseModel):
"""故障诊断与修复请求模型"""
cluster: str | None = Field(None, description="集群UUID")
node: str | None = Field(None, description="节点主机名")
timeFrom: str | None = Field(None, description="ISO起始时间")
keywords: str | None = Field(None, description="关键词")
auto: bool = Field(True, description="是否允许自动修复")
maxSteps: int = Field(3, ge=1, le=6, description="最多工具步数")
model: str | None = Field(None, description="使用的模型名称")
model: str | None = Field(None, description="使用的模型")
class ChatReq(BaseModel):
"""AI 聊天请求模型"""
sessionId: str = Field(..., description="会话ID")
message: str = Field(..., description="用户输入的消息内容")
message: str = Field(..., description="用户输入")
stream: bool = Field(False, description="是否使用流式输出")
context: dict | None = Field(None, description="上下文包含node, agent, model等元数据")
context: dict | None = Field(None, description="上下文包含node, agent, model等")
class HistoryReq(BaseModel):
"""获取历史记录的请求模型"""
sessionId: str
def _get_username(u) -> str:
"""获取当前登录用户的用户名"""
return getattr(u, "username", None) or (u.get("username") if isinstance(u, dict) else None) or "system"
def _get_internal_session_id(user, session_id: str) -> str:
"""
生成内部会话 ID
为了隔离不同用户的会话使用 'username:session_id' 的格式
"""
uname = _get_username(user)
return f"{uname}:{session_id}"
@router.post("/ai/diagnose-repair")
async def diagnose_repair(req: DiagnoseRepairReq, user=Depends(get_current_user), db: AsyncSession = Depends(get_db)):
"""
自动化诊断与修复接口
1. 根据请求参数节点关键词等聚合最近的结构化日志上下文
2. 调用诊断代理diagnosis_agent结合 LLM 和运维工具进行分析
3. 返回诊断结论及修复建议如果允许 auto=True则可能已执行修复
"""
try:
# Aggregate a brief context from the structured logs
filters = []
if req.node:
filters.append(HadoopLog.node_host == req.node)
if req.keywords:
# Fuzzy-match keywords in the info field (simplified to a substring match)
filters.append(HadoopLog.info.ilike(f"%{req.keywords}%"))
stmt = select(HadoopLog).limit(100).order_by(HadoopLog.log_time.desc())
for f in filters:
stmt = stmt.where(f)
rows = (await db.execute(stmt)).scalars().all()
# Use the first 50 logs as context
ctx_logs = [r.to_dict() for r in rows[:50]]
context = {"cluster": req.cluster, "node": req.node, "logs": ctx_logs}
uname = _get_username(user)
# Run the diagnose-and-repair flow
result = await run_diagnose_and_repair(db, uname, context, auto=req.auto, max_steps=req.maxSteps, model=req.model)
return result
except HTTPException:
@ -89,52 +69,39 @@ async def diagnose_repair(req: DiagnoseRepairReq, user=Depends(get_current_user)
except Exception:
raise HTTPException(status_code=500, detail="server_error")
@router.get("/ai/history")
async def get_history(sessionId: str, user=Depends(get_current_user), db: AsyncSession = Depends(get_db)):
"""获取指定会话历史聊天记录"""
"""获取会话历史"""
internal_id = _get_internal_session_id(user, sessionId)
stmt = select(ChatMessage).where(ChatMessage.session_id == internal_id).order_by(ChatMessage.created_at.asc())
rows = (await db.execute(stmt)).scalars().all()
messages = [{"role": r.role, "content": r.content} for r in rows]
return {"messages": messages}
@router.post("/ai/chat")
async def ai_chat(req: ChatReq, user=Depends(get_current_user), db: AsyncSession = Depends(get_db)):
"""
AI 交互式对话接口
1. 管理会话状态与历史记录
2. 构造系统提示词System Prompt定义 AI 的角色和行为规范
3. 支持 Function Calling工具调用AI 可以根据用户需求调用集群管理日志读取故障检测等工具
4. 支持流式或非流式输出
"""
try:
internal_id = _get_internal_session_id(user, req.sessionId)
user_id = user.get("id") if isinstance(user, dict) else getattr(user, "id", None)
# Ensure the session exists
session_stmt = select(ChatSession).where(ChatSession.id == internal_id)
session = (await db.execute(session_stmt)).scalars().first()
if not session:
session = ChatSession(id=internal_id, user_id=user_id, title=req.message[:20])
db.add(session)
# Core prompt: steer the AI to prefer the ops tools for diagnosis
system_prompt = (
"你是 Hadoop 运维诊断助手。输出中文,优先给出根因、影响范围、证据与建议。"
"当用户询问“故障/异常/报错/不可用/打不开/任务失败”等问题时,优先调用 detect_cluster_faults"
"必要时再用 read_cluster_log 补充读取对应组件日志。"
"当用户询问进程/端口/资源/版本等日常运维信息时,优先调用 run_cluster_command例如 jps/df/free/hdfs_report/yarn_node_list"
)
# Merge in dynamic context
if req.context:
if req.context.get("agent"):
system_prompt += f" Your name is {req.context['agent']}."
if req.context.get("node"):
system_prompt += f" You are currently analyzing node: {req.context['node']}."
# Load the 12 most recent history messages
hist_stmt = select(ChatMessage).where(ChatMessage.session_id == internal_id).order_by(ChatMessage.created_at.desc()).limit(12)
hist_rows = (await db.execute(hist_stmt)).scalars().all()
hist_rows = hist_rows[::-1]
@ -144,26 +111,27 @@ async def ai_chat(req: ChatReq, user=Depends(get_current_user), db: AsyncSession
messages.append({"role": r.role, "content": r.content})
messages.append({"role": "user", "content": req.message})
# Persist the user's message
user_msg = ChatMessage(session_id=internal_id, role="user", content=req.message)
db.add(user_msg)
llm = LLMClient()
target_model = req.context.get("model") if req.context else None
# Load the schemas of all available ops tools (OpenAI-style Function Calling)
chat_tools = openai_tools_schema()
# First LLM request
if req.stream:
# For simplicity, streaming does not yet cover the second generation after tool calls;
# tool handling stays non-streaming for now
pass
resp = await llm.chat(messages, tools=chat_tools, stream=False, model=target_model)
choices = resp.get("choices") or []
if not choices:
raise HTTPException(status_code=502, detail="llm_unavailable")
msg = choices[0].get("message") or {}
tool_calls = msg.get("tool_calls") or []
# Handle tool calls issued by the AI
if tool_calls:
messages.append(msg)
for tc in tool_calls:
@ -177,8 +145,6 @@ async def ai_chat(req: ChatReq, user=Depends(get_current_user), db: AsyncSession
tool_result = {"error": "unknown_tool"}
uname = _get_username(user)
# Dispatch to the matching tool
if name == "web_search":
tool_result = await tool_web_search(args.get("query"), args.get("max_results", 5))
elif name == "start_cluster":
@ -188,13 +154,35 @@ async def ai_chat(req: ChatReq, user=Depends(get_current_user), db: AsyncSession
elif name == "read_log":
tool_result = await tool_read_log(db, uname, args.get("node"), args.get("path"), int(args.get("lines", 200)), args.get("pattern"), args.get("sshUser"))
elif name == "read_cluster_log":
tool_result = await tool_read_cluster_log(
db,
uname,
args.get("cluster_uuid"),
args.get("log_type"),
args.get("node_hostname"),
int(args.get("lines", 100))
)
elif name == "detect_cluster_faults":
tool_result = await tool_detect_cluster_faults(
db,
uname,
args.get("cluster_uuid"),
args.get("components"),
args.get("node_hostname"),
int(args.get("lines", 200)),
)
elif name == "run_cluster_command":
tool_result = await tool_run_cluster_command(
db,
uname,
args.get("cluster_uuid"),
args.get("command_key"),
args.get("target"),
args.get("node_hostname"),
int(args.get("timeout", 30)),
int(args.get("limit_nodes", 20)),
)
# Append the tool result to the conversation
messages.append({
"role": "tool",
"tool_call_id": tc.get("id"),
@ -202,7 +190,6 @@ async def ai_chat(req: ChatReq, user=Depends(get_current_user), db: AsyncSession
"content": json.dumps(tool_result, ensure_ascii=False)
})
# After the tool calls, make a second LLM request to produce the final reply
if req.stream:
return await handle_streaming_chat(llm, messages, internal_id, db, tools=chat_tools, model=target_model)
else:
@ -212,15 +199,12 @@ async def ai_chat(req: ChatReq, user=Depends(get_current_user), db: AsyncSession
raise HTTPException(status_code=502, detail="llm_unavailable_after_tool")
msg = choices[0].get("message") or {}
else:
# Direct reply, no tool calls involved
if req.stream:
return await handle_streaming_chat(llm, messages, internal_id, db, tools=chat_tools, model=target_model)
# Process the final reply
reply = msg.get("content") or ""
reasoning = msg.get("reasoning_content") or "" # 部分模型支持思维链展示
reasoning = msg.get("reasoning_content") or ""
# Persist the assistant's reply
asst_msg = ChatMessage(session_id=internal_id, role="assistant", content=reply)
db.add(asst_msg)
await db.commit()
@ -233,9 +217,7 @@ async def ai_chat(req: ChatReq, user=Depends(get_current_user), db: AsyncSession
print(f"AI Chat Error: {str(e)}")
raise HTTPException(status_code=500, detail=f"server_error: {str(e)}")
async def handle_streaming_chat(llm: LLMClient, messages: list, session_id: str, db: AsyncSession, tools=None, model: str = None):
"""处理流式对话响应"""
async def event_generator():
full_reply = ""
full_reasoning = ""
@ -255,11 +237,9 @@ async def handle_streaming_chat(llm: LLMClient, messages: list, session_id: str,
full_reply += content
if reasoning:
full_reasoning += reasoning
# Emit Server-Sent Events (SSE) formatted chunks
yield f"data: {json.dumps({'content': content, 'reasoning': reasoning}, ensure_ascii=False)}\n\n"
finally:
# When the stream ends, persist the full reply to the database
try:
if full_reply:
asst_msg = ChatMessage(session_id=session_id, role="assistant", content=full_reply)

@ -14,31 +14,26 @@ from ..config import now_bj
router = APIRouter()
class LoginRequest(BaseModel):
"""登录请求模型"""
username: str
password: str
class RegisterRequest(BaseModel):
"""注册请求模型"""
username: str
email: str
password: str
fullName: str
async def _get_user_id(db: AsyncSession, username: str) -> int | None:
"""根据用户名获取用户 ID"""
res = await db.execute(text("SELECT id FROM users WHERE username=:u LIMIT 1"), {"u": username})
row = res.first()
return row[0] if row else None
async def _get_role_id(db: AsyncSession, role_key: str) -> int | None:
"""根据角色 Key 获取角色 ID"""
res = await db.execute(text("SELECT id FROM roles WHERE role_key=:k LIMIT 1"), {"k": role_key})
row = res.first()
return row[0] if row else None
async def _ensure_observer_role(db: AsyncSession) -> int:
"""确保数据库中存在 'observer' (观察员) 角色,不存在则创建"""
rid = await _get_role_id(db, "observer")
if rid is not None:
return rid
@ -55,7 +50,6 @@ async def _ensure_observer_role(db: AsyncSession) -> int:
return rid2
async def _map_user_role(db: AsyncSession, username: str, role_key: str) -> None:
"""为指定用户分配角色"""
uid = await _get_user_id(db, username)
if uid is None:
raise HTTPException(status_code=500, detail="user_not_found_after_register")
@ -65,13 +59,11 @@ async def _map_user_role(db: AsyncSession, username: str, role_key: str) -> None
rid = await _ensure_observer_role(db)
else:
raise HTTPException(status_code=400, detail="role_not_exist")
# Clear old mappings before inserting the new one (simplification: one role per user)
await db.execute(text("DELETE FROM user_role_mapping WHERE user_id=:uid"), {"uid": uid})
await db.execute(text("INSERT INTO user_role_mapping(user_id, role_id) VALUES(:uid, :rid)"), {"uid": uid, "rid": rid})
await db.commit()
async def _get_user_roles(db: AsyncSession, user_id: int) -> list[str]:
"""查询指定用户拥有的所有角色 Key"""
res = await db.execute(
text("SELECT r.role_key FROM roles r JOIN user_role_mapping urm ON r.id = urm.role_id WHERE urm.user_id = :uid"),
{"uid": user_id},
@ -79,7 +71,6 @@ async def _get_user_roles(db: AsyncSession, user_id: int) -> list[str]:
return [row[0] for row in res.all()]
async def _get_role_permissions(db: AsyncSession, role_keys: list[str]) -> list[str]:
"""根据角色列表查询关联的所有权限 Key"""
if not role_keys:
return []
res = await db.execute(
@ -96,13 +87,6 @@ async def _get_role_permissions(db: AsyncSession, role_keys: list[str]) -> list[
@router.post("/user/login")
async def login(req: LoginRequest, db: AsyncSession = Depends(get_db)):
"""
用户登录接口
1. 支持内置演示账号的快速登录
2. 支持数据库用户的哈希密码校验
3. 登录成功后返回 JWT Token用户信息角色及权限列表
"""
# 演示账号逻辑
demo = {"admin": "admin123", "ops": "ops123", "obs": "obs123"}
if req.username in demo and req.password == demo[req.username]:
exp = now_bj() + timedelta(minutes=JWT_EXPIRE_MINUTES)
@ -112,7 +96,7 @@ async def login(req: LoginRequest, db: AsyncSession = Depends(get_db)):
uid = await _get_user_id(db, req.username)
roles = await _get_user_roles(db, uid) if uid else []
if not roles:
# If the DB has no record, fall back to a default mapping
role_map = {"admin": ["admin"], "ops": ["operator"], "obs": ["observer"]}
roles = role_map.get(req.username, [])
@ -126,8 +110,6 @@ async def login(req: LoginRequest, db: AsyncSession = Depends(get_db)):
"roles": roles,
"permissions": permissions
}
# Regular database-user login
try:
result = await db.execute(select(User).where(User.username == req.username).limit(1))
user = result.scalars().first()
@ -137,8 +119,6 @@ async def login(req: LoginRequest, db: AsyncSession = Depends(get_db)):
raise HTTPException(status_code=403, detail="inactive_user")
if not bcrypt.verify(req.password, user.password_hash):
raise HTTPException(status_code=401, detail="invalid_credentials")
# Update the last-login time
await db.execute(
update(User).where(User.id == user.id).values(last_login=func.now(), updated_at=func.now())
)
@ -165,12 +145,6 @@ async def login(req: LoginRequest, db: AsyncSession = Depends(get_db)):
@router.post("/user/register")
async def register(req: RegisterRequest, db: AsyncSession = Depends(get_db)):
"""
用户注册接口
1. 包含详尽的用户名邮箱密码复杂度和姓名的校验
2. 检查用户名和邮箱的唯一性
3. 注册成功后自动分配 'observer' 角色并返回登录态
"""
try:
errors: list[dict] = []
# Username validation: 3-50 chars, starts with a letter, letters/digits/underscores allowed
@ -204,8 +178,6 @@ async def register(req: RegisterRequest, db: AsyncSession = Depends(get_db)):
exists_email = await db.execute(select(User.id).where(User.email == req.email).limit(1))
if exists_email.scalars().first():
raise HTTPException(status_code=400, detail={"message": "该邮箱已被绑定", "code": "email_exists"})
# Store the password as a bcrypt hash
password_hash = bcrypt.hash(req.password)
user = User(
username=req.username,
@ -220,8 +192,6 @@ async def register(req: RegisterRequest, db: AsyncSession = Depends(get_db)):
db.add(user)
await db.flush()
await db.commit()
# Assign the observer role by default after registration
await _map_user_role(db, req.username, "observer")
permissions = await _get_role_permissions(db, ["observer"])

@ -15,16 +15,10 @@ router = APIRouter()
def _get_username(u) -> str:
"""
辅助函数从用户对象可能是字典或 SQLAlchemy 模型中提取用户名
"""
return getattr(u, "username", None) or (u.get("username") if isinstance(u, dict) else None)
class NodeCreateItem(BaseModel):
"""
创建集群时关联节点的详细信息请求模型
"""
hostname: str
ip_address: str
ssh_user: str
@ -32,11 +26,8 @@ class NodeCreateItem(BaseModel):
description: str | None = None
class ClusterCreateRequest(BaseModel):
"""
注册新集群的完整请求模型包含集群基本信息和节点列表
"""
name: str
type: str  # supported: hadoop, spark, kubernetes
node_count: int
health_status: str
cpu_avg: float | None = None
@ -51,28 +42,17 @@ class ClusterCreateRequest(BaseModel):
@router.get("/clusters")
async def list_clusters(user=Depends(get_current_user), db: AsyncSession = Depends(get_db)):
"""
获取集群列表接口
1. 获取当前登录用户的用户名
2. 查询用户在数据库中的 ID
3. 通过关联表 user_cluster_mapping 查找该用户有权访问的所有集群 ID
4. 返回这些集群的详细配置信息
"""
"""按当前用户归属返回其可访问的集群列表。"""
try:
name = _get_username(user)
# 1. Get the user ID
uid_res = await db.execute(text("SELECT id FROM users WHERE username=:un LIMIT 1"), {"un": name})
uid_row = uid_res.first()
if not uid_row:
return {"clusters": []}
# 2. 查询用户关联的集群 ID 列表
ids_res = await db.execute(text("SELECT cluster_id FROM user_cluster_mapping WHERE user_id=:uid"), {"uid": uid_row[0]})
cluster_ids = [r[0] for r in ids_res.all()]
if not cluster_ids:
return {"clusters": []}
# 3. 查询集群详细信息
result = await db.execute(select(Cluster).where(Cluster.id.in_(cluster_ids)))
rows = result.scalars().all()
data = []
@ -104,20 +84,10 @@ async def create_cluster(
user=Depends(PermissionChecker(["cluster:register"])),
db: AsyncSession = Depends(get_db)
):
"""
Register a cluster (requires cluster:register).

Fetches the real HDFS cluster UUID from the NameNode over SSH to guarantee
physical uniqueness. If the UUID already exists, only the user-cluster mapping
is added; otherwise the cluster name is checked for uniqueness, every node
passes an SSH connectivity test, and the cluster plus its nodes are persisted.
The user is then mapped with the admin (for admin) or operator role.
"""
try:
name = _get_username(user)
# PermissionChecker already enforces authorization; no hard-coded role check needed
# Parameter validation: type and status
valid_types = {"hadoop", "spark", "kubernetes"}
@ -132,7 +102,7 @@ async def create_cluster(
if errors:
raise HTTPException(status_code=400, detail={"errors": errors})
# 1. Fetch the real HDFS cluster UUID from the NameNode over SSH
cluster_uuid, err = get_hdfs_cluster_id(str(req.namenode_ip), req.nodes[0].ssh_user, req.nodes[0].ssh_password)
if not cluster_uuid:
raise HTTPException(status_code=400, detail={"errors": [{"field": "namenode_ip", "message": f"无法获取集群ID: {err}"}]})
@ -152,7 +122,7 @@ async def create_cluster(
if name_exists.scalars().first():
raise HTTPException(status_code=400, detail={"errors": [{"field": "name", "message": "集群名称已存在"}]})
# SSH connectivity pre-check: every node must be reachable
ssh_errors: list[dict] = []
for idx, n_req in enumerate(req.nodes):
ip = getattr(n_req, "ip_address", None) or getattr(n_req, "ip", None)
@ -173,7 +143,6 @@ async def create_cluster(
new_uuid = cluster_uuid
# 创建集群记录
c = Cluster(
uuid=new_uuid,
name=req.name,
@ -192,9 +161,9 @@ async def create_cluster(
updated_at=now_bj(),
)
db.add(c)
await db.flush()  # flush to obtain c.id for node association
# Bulk-insert the node records
for n_req in req.nodes:
node_uuid = str(uuidlib.uuid4())
node = Node(
@ -213,14 +182,12 @@ async def create_cluster(
# 3. 建立用户映射 (无论集群是新注册还是已存在)
uid_res = await db.execute(text("SELECT id FROM users WHERE username=:un LIMIT 1"), {"un": name})
uid_row = uid_res.first()
# Role assignment: the admin user gets the admin role, everyone else gets operator
role_key = "admin" if name == "admin" else "operator"
rid_res = await db.execute(text("SELECT id FROM roles WHERE role_key=:rk LIMIT 1"), {"rk": role_key})
rid_row = rid_res.first()
if uid_row and rid_row:
# Link user, cluster, and role; ON CONFLICT avoids duplicate inserts
await db.execute(
text("INSERT INTO user_cluster_mapping(user_id, cluster_id, role_id) VALUES (:uid,:cid,:rid) ON CONFLICT (user_id, cluster_id) DO NOTHING"),
{"uid": uid_row[0], "cid": c.id, "rid": rid_row[0]}
@ -247,28 +214,19 @@ async def delete_cluster(
user=Depends(PermissionChecker(["cluster:delete"])),
db: AsyncSession = Depends(get_db)
):
"""
Deregister a cluster (requires cluster:delete): validate the UUID, delete the
cluster row (nodes and metrics cascade via foreign keys), then clean up user mappings.
"""
try:
name = _get_username(user)
try:
uo = uuidlib.UUID(uuid)
except Exception:
raise HTTPException(status_code=400, detail={"errors": [{"field": "uuid", "message": "UUID 格式不正确"}]})
res = await db.execute(select(Cluster).where(Cluster.uuid == str(uo)).limit(1))
c = res.scalars().first()
if not c:
return {"ok": True}
# 执行物理删除
await db.execute(delete(Cluster).where(Cluster.id == c.id))
# 清理关联映射
await db.execute(text("DELETE FROM user_cluster_mapping WHERE cluster_id=:cid"), {"cid": c.id})
await db.commit()
return {"ok": True}
@ -276,4 +234,3 @@ async def delete_cluster(
raise
except Exception:
raise HTTPException(status_code=500, detail="server_error")

@ -15,17 +15,14 @@ router = APIRouter()
def _get_username(u) -> str:
"""获取当前登录用户的用户名"""
return getattr(u, "username", None) or (u.get("username") if isinstance(u, dict) else None)
def _now():
"""获取当前北京时间"""
return now_bj()
def _map_level(level: str) -> str:
"""将前端或采集到的等级映射为统一的日志等级"""
lv = (level or "").lower()
if lv in ("critical", "fatal"):
return "FATAL"
@ -37,7 +34,6 @@ def _map_level(level: str) -> str:
class FaultCreate(BaseModel):
"""创建故障记录的请求模型"""
id: str | None = None
type: str
level: str
@ -49,7 +45,6 @@ class FaultCreate(BaseModel):
class FaultUpdate(BaseModel):
"""更新故障记录的请求模型"""
status: str | None = None
title: str | None = None
@ -64,26 +59,16 @@ async def list_faults(
page: int = Query(1, ge=1),
size: int = Query(10, ge=1, le=100),
):
"""
List fault records with pagination, filterable by cluster, node, and time range; faults are stored in hadoop_logs with title='fault'.
"""
try:
# 基础查询语句,过滤故障类型
stmt = select(HadoopLog).where(HadoopLog.title == "fault")
count_stmt = select(func.count(HadoopLog.log_id)).where(HadoopLog.title == "fault")
# 按集群过滤
if cluster:
stmt = stmt.where(HadoopLog.cluster_name == cluster)
count_stmt = count_stmt.where(HadoopLog.cluster_name == cluster)
# 按节点过滤
if node:
stmt = stmt.where(HadoopLog.node_host == node)
count_stmt = count_stmt.where(HadoopLog.node_host == node)
# 按时间过滤
if time_from:
try:
tf = datetime.fromisoformat(time_from.replace("Z", "+00:00"))
@ -96,7 +81,6 @@ async def list_faults(
except Exception:
pass
# 分页和排序(按时间倒序)
stmt = stmt.order_by(HadoopLog.log_time.desc()).offset((page - 1) * size).limit(size)
rows = (await db.execute(stmt)).scalars().all()
total = (await db.execute(count_stmt)).scalar() or 0
@ -105,7 +89,6 @@ async def list_faults(
for r in rows:
meta = {}
try:
# 解析存储在 info 字段中的 JSON 扩展信息
if r.info:
meta = json.loads(r.info)
except Exception:
@ -131,25 +114,19 @@ async def list_faults(
@router.post("/faults")
async def create_fault(req: FaultCreate, user=Depends(get_current_user), db: AsyncSession = Depends(get_db)):
"""
Create a fault record manually (admin and ops only); the fault info is serialized to JSON in HadoopLog.info.
"""
try:
uname = _get_username(user)
if uname not in {"admin", "ops"}:
raise HTTPException(status_code=403, detail="not_allowed")
# Resolve the cluster display name
cluster_name = req.cluster or "unknown"
if req.cluster and "-" in req.cluster:  # likely a UUID; try to resolve the real name
res = await db.execute(select(Cluster.name).where(Cluster.uuid == req.cluster).limit(1))
name = res.scalars().first()
if name:
cluster_name = name
# 处理故障发生时间
ts = _now()
if req.created:
try:
@ -161,12 +138,11 @@ async def create_fault(req: FaultCreate, user=Depends(get_current_user), db: Asy
except Exception:
pass
# 构建元数据并保存
meta = {"type": req.type, "status": req.status, "title": req.title, "cluster": req.cluster, "node": req.node}
log = HadoopLog(
cluster_name=cluster_name,
node_host=req.node or "unknown",
title="fault",  # fixed marker identifying fault records
info=json.dumps(meta, ensure_ascii=False),
log_time=ts
)
@ -182,16 +158,11 @@ async def create_fault(req: FaultCreate, user=Depends(get_current_user), db: Asy
@router.put("/faults/{fid}")
async def update_fault(fid: int, req: FaultUpdate, user=Depends(get_current_user), db: AsyncSession = Depends(get_db)):
"""
Update a fault record: change its status (resolved, false positive) or title.
"""
try:
uname = _get_username(user)
if uname not in {"admin", "ops"}:
raise HTTPException(status_code=403, detail="not_allowed")
# 查找记录
res = await db.execute(select(HadoopLog).where(HadoopLog.log_id == fid, HadoopLog.title == "fault").limit(1))
row = res.scalars().first()
if not row:
@ -204,7 +175,6 @@ async def update_fault(fid: int, req: FaultUpdate, user=Depends(get_current_user
except Exception:
pass
# 更新 JSON 字段中的信息
if req.status is not None:
meta["status"] = req.status
if req.title is not None:
@ -222,10 +192,6 @@ async def update_fault(fid: int, req: FaultUpdate, user=Depends(get_current_user
@router.delete("/faults/{fid}")
async def delete_fault(fid: int, user=Depends(get_current_user), db: AsyncSession = Depends(get_db)):
"""
Delete a fault record (admin and ops only).
"""
try:
uname = _get_username(user)
if uname not in {"admin", "ops"}:

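Fault records above are hadoop_logs rows with title="fault" and a JSON blob in info, so the storage round-trip reduces to the following (all values hypothetical):

import json

meta = {"type": "disk", "status": "open", "title": "DataNode disk full",
        "cluster": "prod-hadoop", "node": "hadoop102"}
encoded = json.dumps(meta, ensure_ascii=False)  # what create_fault stores in HadoopLog.info
decoded = json.loads(encoded)                   # what list_faults parses back out
assert decoded["status"] == "open"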
@ -14,7 +14,6 @@ router = APIRouter()
class ExecLogCreate(BaseModel):
"""创建执行日志的请求模型"""
from_user_id: int
cluster_name: str
description: str | None = None
@ -23,22 +22,16 @@ class ExecLogCreate(BaseModel):
class ExecLogUpdate(BaseModel):
"""更新执行日志的请求模型"""
description: str | None = None
start_time: str | None = None
end_time: str | None = None
def _now() -> datetime:
"""获取当前北京时间"""
return now_bj()
def _parse_time(s: str | None) -> datetime | None:
"""
Parse an ISO time string into a Beijing-timezone datetime; handles the 'Z' suffix.
"""
if not s:
return None
try:
@ -52,13 +45,7 @@ def _parse_time(s: str | None) -> datetime | None:
@router.get("/exec-logs")
async def list_exec_logs(user=Depends(get_current_user), db: AsyncSession = Depends(get_db)):
"""
获取执行日志列表接口
- 关联查询 users 表以获取执行者的用户名
- 按开始时间倒序排列
"""
try:
# 执行联表查询
stmt = (
select(HadoopExecLog, User.username)
.join(User, HadoopExecLog.from_user_id == User.id)
@ -69,10 +56,8 @@ async def list_exec_logs(user=Depends(get_current_user), db: AsyncSession = Depe
items = []
for log, username in rows:
# 转换为字典并添加用户名信息
d = log.to_dict()
d["username"] = username
# 移除敏感或不必要的内部 ID
if "from_user_id" in d:
del d["from_user_id"]
items.append(d)
@ -85,10 +70,6 @@ async def list_exec_logs(user=Depends(get_current_user), db: AsyncSession = Depe
@router.post("/exec-logs")
async def create_exec_log(req: ExecLogCreate, user=Depends(get_current_user), db: AsyncSession = Depends(get_db)):
"""
创建执行日志接口
- 记录任务开始执行时的元数据
"""
try:
st = _parse_time(req.start_time)
et = _parse_time(req.end_time)
@ -101,7 +82,7 @@ async def create_exec_log(req: ExecLogCreate, user=Depends(get_current_user), db
end_time=et
)
db.add(row)
await db.flush()  # flush to obtain the auto-generated ID
await db.commit()
return {"ok": True, "id": row.id}
except HTTPException:
@ -113,15 +94,10 @@ async def create_exec_log(req: ExecLogCreate, user=Depends(get_current_user), db
@router.put("/exec-logs/{log_id}")
async def update_exec_log(log_id: int, req: ExecLogUpdate, user=Depends(get_current_user), db: AsyncSession = Depends(get_db)):
"""
更新执行日志接口
- 常用于任务结束时更新结束时间和执行结果描述
"""
try:
st = _parse_time(req.start_time)
et = _parse_time(req.end_time)
values: dict = {}
# 仅更新非空字段
if req.description is not None:
values["description"] = req.description
if st is not None:
@ -132,7 +108,6 @@ async def update_exec_log(log_id: int, req: ExecLogUpdate, user=Depends(get_curr
if not values:
return {"ok": True}
# 执行更新操作
await db.execute(update(HadoopExecLog).where(HadoopExecLog.id == log_id).values(**values))
await db.commit()
return {"ok": True}
@ -144,9 +119,6 @@ async def update_exec_log(log_id: int, req: ExecLogUpdate, user=Depends(get_curr
@router.delete("/exec-logs/{log_id}")
async def delete_exec_log(log_id: int, user=Depends(get_current_user), db: AsyncSession = Depends(get_db)):
"""
删除执行日志接口
"""
try:
await db.execute(delete(HadoopExecLog).where(HadoopExecLog.id == log_id))
await db.commit()

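A self-contained sketch of the ISO-8601 handling _parse_time implements, including the 'Z' suffix normalization; treating naive timestamps as Beijing time follows the config defaults:

from datetime import datetime
from zoneinfo import ZoneInfo

BJ_TZ = ZoneInfo("Asia/Shanghai")

def parse_iso(s: str | None) -> datetime | None:
    if not s:
        return None
    try:
        dt = datetime.fromisoformat(s.replace("Z", "+00:00"))
        # Attach the Beijing timezone when the input carries no offset
        return dt if dt.tzinfo else dt.replace(tzinfo=BJ_TZ)
    except ValueError:
        return None

parse_iso("2030-01-01T08:00:00Z")  # tz-aware UTC datetime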
@ -29,10 +29,6 @@ from ..schemas import (
router = APIRouter()
async def _ensure_metrics_schema(db: AsyncSession):
"""
确保指标相关的数据库表结构存在
如果表不存在则创建并检查必要的列是否存在
"""
await db.execute(text("""
CREATE TABLE IF NOT EXISTS node_metrics (
id SERIAL PRIMARY KEY,
@ -54,7 +50,6 @@ async def _ensure_metrics_schema(db: AsyncSession):
created_at TIMESTAMPTZ
)
"""))
# 动态添加缺失的列,确保表结构最新
await db.execute(text("ALTER TABLE node_metrics ADD COLUMN IF NOT EXISTS node_id INTEGER"))
await db.execute(text("ALTER TABLE node_metrics ADD COLUMN IF NOT EXISTS hostname VARCHAR(100)"))
await db.execute(text("ALTER TABLE node_metrics ADD COLUMN IF NOT EXISTS cpu_usage DOUBLE PRECISION"))
@ -69,10 +64,6 @@ async def _ensure_metrics_schema(db: AsyncSession):
await db.commit()
def _parse_time(s: str | None) -> datetime | None:
"""
解析 ISO 格式的时间字符串并处理时区
默认将无时区的时间视为北京时间
"""
if not s:
return None
try:
@ -94,10 +85,6 @@ async def list_logs(
page: int = Query(1, ge=1),
size: int = Query(10, ge=1, le=100),
):
"""
List structured Hadoop logs with pagination, filterable by cluster, node, keyword, and time range.
"""
try:
stmt = select(HadoopLog)
count_stmt = select(func.count(HadoopLog.log_id))
@ -108,7 +95,6 @@ async def list_logs(
if node:
filters.append(HadoopLog.node_host == node)
if source:
# 在标题、信息和主机名中进行模糊匹配
like = f"%{source}%"
filters.append(or_(HadoopLog.title.ilike(like), HadoopLog.info.ilike(like), HadoopLog.node_host.ilike(like)))
tf = _parse_time(time_from)
@ -119,7 +105,6 @@ async def list_logs(
stmt = stmt.where(f)
count_stmt = count_stmt.where(f)
# 排序并分页
stmt = stmt.order_by(HadoopLog.log_time.desc()).offset((page - 1) * size).limit(size)
rows = (await db.execute(stmt)).scalars().all()
total = (await db.execute(count_stmt)).scalar() or 0
@ -143,9 +128,6 @@ async def list_logs(
raise HTTPException(status_code=500, detail="server_error")
async def get_node_ip(db: AsyncSession, node_name: str) -> str:
"""
根据节点主机名从数据库获取其 IP 地址
"""
result = await db.execute(select(Node.ip_address).where(Node.hostname == node_name))
ip = result.scalar_one_or_none()
if not ip:
@ -154,23 +136,20 @@ async def get_node_ip(db: AsyncSession, node_name: str) -> str:
@router.get("/hadoop/nodes/")
async def get_hadoop_nodes(user=Depends(get_current_user), db: AsyncSession = Depends(get_db)):
"""
Get the hostnames of all Hadoop nodes.
"""
# Assuming all nodes in DB are relevant, or filter by Cluster type if needed
stmt = select(Node.hostname).join(Cluster)
# Optional: .where(Cluster.type.ilike('%hadoop%'))
result = await db.execute(stmt)
nodes = result.scalars().all()
return NodeListResponse(nodes=nodes)
@router.get("/hadoop/logs/{node_name}/{log_type}/", response_model=LogResponse)
async def get_hadoop_log(node_name: str, log_type: str, user=Depends(get_current_user), db: AsyncSession = Depends(get_db)):
"""
Read a raw log of the given type from a specific Hadoop node.
"""
ip = await get_node_ip(db, node_name)
try:
# Read the log file content via log_reader
log_content = log_reader.read_log(node_name, log_type, ip=ip)
return LogResponse(
node_name=node_name,
@ -182,9 +161,7 @@ async def get_hadoop_log(node_name: str, log_type: str, user=Depends(get_current
@router.get("/hadoop/logs/all/{log_type}/", response_model=MultiLogResponse)
async def get_all_hadoop_nodes_log(log_type: str, user=Depends(get_current_user), db: AsyncSession = Depends(get_db)):
"""
Read logs of the given type from all Hadoop nodes.
"""
stmt = select(Node.hostname, Node.ip_address).join(Cluster)
result = await db.execute(stmt)
nodes_data = result.all()
@ -192,7 +169,7 @@ async def get_all_hadoop_nodes_log(log_type: str, user=Depends(get_current_user)
nodes_list = [{"name": n[0], "ip": str(n[1])} for n in nodes_data]
try:
# Read logs from all nodes in one batch
logs = log_reader.read_all_nodes_log(nodes_list, log_type)
return MultiLogResponse(logs=logs)
except Exception as e:
@ -200,12 +177,10 @@ async def get_all_hadoop_nodes_log(log_type: str, user=Depends(get_current_user)
@router.get("/hadoop/logs/files/{node_name}/", response_model=LogFilesResponse)
async def get_hadoop_log_files(node_name: str, user=Depends(get_current_user), db: AsyncSession = Depends(get_db)):
"""
List the available Hadoop log files on a specific node.
"""
ip = await get_node_ip(db, node_name)
try:
# Get log files list
log_files = log_reader.get_log_files_list(node_name, ip=ip)
return LogFilesResponse(
node_name=node_name,
@ -214,12 +189,10 @@ async def get_hadoop_log_files(node_name: str, user=Depends(get_current_user), d
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
# Log collection management endpoints
@router.get("/hadoop/collectors/status/")
async def get_hadoop_collectors_status(user=Depends(get_current_user)):
"""
Get the status of all Hadoop log collectors.
"""
status = log_collector.get_collectors_status()
return {
"collectors": status,
@ -228,9 +201,7 @@ async def get_hadoop_collectors_status(user=Depends(get_current_user)):
@router.post("/hadoop/collectors/start/{node_name}/{log_type}/")
async def start_hadoop_collector(node_name: str, log_type: str, interval: int = 5, user=Depends(get_current_user), db: AsyncSession = Depends(get_db)):
"""
Start log collection for a specific Hadoop node and log type.
"""
ip = await get_node_ip(db, node_name)
try:
log_collector.start_collection(node_name, log_type, ip=ip, interval=interval)
@ -243,9 +214,8 @@ async def start_hadoop_collector(node_name: str, log_type: str, interval: int =
@router.post("/hadoop/collectors/stop/{node_name}/{log_type}/")
async def stop_hadoop_collector(node_name: str, log_type: str, user=Depends(get_current_user)):
"""
Stop log collection for a specific Hadoop node and log type.
"""
# stop doesn't need IP as it just stops the thread by ID
try:
log_collector.stop_collection(node_name, log_type)
return {
@ -256,9 +226,7 @@ async def stop_hadoop_collector(node_name: str, log_type: str, user=Depends(get_
@router.post("/hadoop/collectors/stop/all/")
async def stop_all_hadoop_collectors(user=Depends(get_current_user)):
"""
Stop all running Hadoop log collectors.
"""
try:
log_collector.stop_all_collections()
return {
@ -269,9 +237,7 @@ async def stop_all_hadoop_collectors(user=Depends(get_current_user)):
@router.post("/hadoop/collectors/set-interval/{interval}/")
async def set_hadoop_collection_interval(interval: int, user=Depends(get_current_user)):
"""
Set the collection interval for all Hadoop collectors.
"""
try:
log_collector.set_collection_interval(interval)
return {
@ -282,9 +248,7 @@ async def set_hadoop_collection_interval(interval: int, user=Depends(get_current
@router.post("/hadoop/collectors/set-log-dir/{log_dir}/")
async def set_hadoop_log_directory(log_dir: str, user=Depends(get_current_user)):
"""
Set the base log directory for all Hadoop collectors.
"""
try:
log_collector.set_log_dir(log_dir)
return {
@ -295,9 +259,7 @@ async def set_hadoop_log_directory(log_dir: str, user=Depends(get_current_user))
@router.post("/hadoop/nodes/{node_name}/execute/")
async def execute_hadoop_command(node_name: str, command: str, timeout: int = 30, user=Depends(get_current_user)):
"""
Execute an SSH command on a specific Hadoop node.
"""
try:
from sqlalchemy import select
from ..db import SessionLocal
@ -307,9 +269,9 @@ async def execute_hadoop_command(node_name: str, command: str, timeout: int = 30
ip = res.scalar_one_or_none()
if not ip:
raise HTTPException(status_code=404, detail=f"Node {node_name} not found")
ssh_client = ssh_manager.get_connection(node_name, ip=str(ip))
# Execute command with timeout
stdout, stderr = ssh_client.execute_command_with_timeout(command, timeout)
return {
@ -324,34 +286,25 @@ async def execute_hadoop_command(node_name: str, command: str, timeout: int = 30
@router.post("/hadoop/collectors/start-by-cluster/{cluster_uuid}/")
async def start_collectors_by_cluster(cluster_uuid: str, interval: int = 5, user=Depends(get_current_user), db: AsyncSession = Depends(get_db)):
"""
Start log collection for all nodes of the cluster (by UUID), only for services whose log files exist.
"""
try:
# 获取集群 ID
cid_res = await db.execute(select(Cluster.id).where(Cluster.uuid == cluster_uuid).limit(1))
cid = cid_res.scalar_one_or_none()
if cid is None:
raise HTTPException(status_code=404, detail="cluster_not_found")
# 获取集群下的所有节点
nodes_res = await db.execute(select(Node.hostname, Node.ip_address).where(Node.cluster_id == cid))
rows = nodes_res.all()
if not rows:
return {"started": 0, "nodes": []}
started = []
for hn, ip in rows:
ip_s = str(ip)
files = []
try:
# 寻找工作的日志目录并列出文件
log_reader.find_working_log_dir(hn, ip_s)
files = log_reader.get_log_files_list(hn, ip=ip_s)
except Exception:
files = []
# 根据文件名识别 Hadoop 服务类型
services = []
for fn in files:
f = fn.lower()
@ -367,9 +320,7 @@ async def start_collectors_by_cluster(cluster_uuid: str, interval: int = 5, user
services.append("nodemanager")
elif "historyserver" in f:
services.append("historyserver")
services = list(set(services))
# 为识别出的每个服务启动采集
for t in services:
ok = False
try:
@ -378,7 +329,6 @@ async def start_collectors_by_cluster(cluster_uuid: str, interval: int = 5, user
ok = False
if ok:
started.append(f"{hn}_{t}")
return {"started": len(started), "nodes": started, "interval": interval}
except HTTPException:
raise
@ -387,25 +337,19 @@ async def start_collectors_by_cluster(cluster_uuid: str, interval: int = 5, user
@router.post("/hadoop/collectors/backfill-by-cluster/{cluster_uuid}/")
async def backfill_logs_by_cluster(cluster_uuid: str, user=Depends(get_current_user), db: AsyncSession = Depends(get_db)):
"""
Backfill historical logs: scan the log files on every node in the cluster and persist their contents to the database.
"""
try:
cid_res = await db.execute(select(Cluster.id).where(Cluster.uuid == cluster_uuid).limit(1))
cid = cid_res.scalar_one_or_none()
if cid is None:
raise HTTPException(status_code=404, detail="cluster_not_found")
nodes_res = await db.execute(select(Node.hostname, Node.ip_address).where(Node.cluster_id == cid))
rows = nodes_res.all()
if not rows:
return {"backfilled": 0, "details": []}
details = []
for hn, ip in rows:
ip_s = str(ip)
ssh_client = ssh_manager.get_connection(hn, ip=ip_s)
# 候选的日志目录
candidates = [
"/opt/module/hadoop-3.1.3/logs",
"/usr/local/hadoop/logs",
@ -416,13 +360,11 @@ async def backfill_logs_by_cluster(cluster_uuid: str, user=Depends(get_current_u
"/var/log/hadoop",
]
base = None
# 寻找存在的日志目录
for d in candidates:
out, err = ssh_client.execute_command(f"ls -1 {d} 2>/dev/null")
if not err and out.strip():
base = d
break
services = []
count = 0
if base:
@ -431,7 +373,6 @@ async def backfill_logs_by_cluster(cluster_uuid: str, user=Depends(get_current_u
for fn in out.splitlines():
f = fn.lower()
t = None
# 识别服务类型
if "namenode" in f:
t = "namenode"
elif "secondarynamenode" in f:
@ -444,17 +385,13 @@ async def backfill_logs_by_cluster(cluster_uuid: str, user=Depends(get_current_u
t = "nodemanager"
elif "historyserver" in f:
t = "historyserver"
if t:
services.append(t)
# 读取整个文件内容并保存
out2, err2 = ssh_client.execute_command(f"cat {base}/{fn} 2>/dev/null")
if not err2 and out2:
log_collector._save_log_chunk(hn, t, out2)
count += out2.count("\n")
details.append({"node": hn, "services": list(set(services)), "lines": count})
total_lines = sum(d["lines"] for d in details)
return {"backfilled": total_lines, "details": details}
except HTTPException:
@ -464,36 +401,26 @@ async def backfill_logs_by_cluster(cluster_uuid: str, user=Depends(get_current_u
@router.post("/metrics/{cluster_uuid}/")
async def sync_metrics(cluster_uuid: str, user=Depends(get_current_user), db: AsyncSession = Depends(get_db)):
"""
手动同步集群及节点的 CPU 和内存指标
"""
try:
from sqlalchemy import select
try:
metrics_collector.stop_all()
except Exception:
pass
# 获取集群信息
cid_res = await db.execute(select(Cluster.id, Cluster.name).where(Cluster.uuid == cluster_uuid).limit(1))
row = cid_res.first()
if not row:
raise HTTPException(status_code=404, detail="cluster_not_found")
cid, cname = row
# 获取节点信息
nodes_res = await db.execute(select(Node.id, Node.hostname, Node.ip_address).where(Node.cluster_id == cid))
rows = nodes_res.all()
now = now_bj()
details = []
for nid, hn, ip in rows:
ssh_client = ssh_manager.get_connection(hn, ip=str(ip))
# 计算 CPU 使用率 (通过采样两次 /proc/stat)
out1, err1 = ssh_client.execute_command("cat /proc/stat | head -n 1")
time.sleep(0.5)
out2, err2 = ssh_client.execute_command("cat /proc/stat | head -n 1")
cpu_pct = 0.0
if not err1 and not err2 and out1.strip() and out2.strip():
p1 = out1.strip().split()
@ -506,8 +433,6 @@ async def sync_metrics(cluster_uuid: str, user=Depends(get_current_user), db: As
total = (get2(0) - get1(0)) + (get2(1) - get1(1)) + (get2(2) - get1(2)) + idle + (get2(5) - get1(5)) + (get2(6) - get1(6)) + (get2(7) - get1(7))
if total > 0:
cpu_pct = round((1.0 - idle / total) * 100.0, 2)
# 计算内存使用率 (通过 /proc/meminfo)
outm, errm = ssh_client.execute_command("cat /proc/meminfo")
mem_pct = 0.0
if not errm and outm.strip():
@ -520,28 +445,15 @@ async def sync_metrics(cluster_uuid: str, user=Depends(get_current_user), db: As
ma = int(line.split()[1])
if mt > 0:
mem_pct = round((1.0 - (ma / mt)) * 100.0, 2)
details.append({"node": hn, "cpu": cpu_pct, "memory": mem_pct})
# 计算集群平均指标
if details:
ca = round(sum(d["cpu"] for d in details) / len(details), 3)
ma = round(sum(d["memory"] for d in details) / len(details), 3)
else:
ca = 0.0
ma = 0.0
return {
"cluster": {
"cpu_avg": round(ca, 2),
"memory_avg": round(ma, 2),
"time": now.isoformat(),
"cluster_name": cname
},
"nodes": details
}
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))

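The two-sample /proc/stat arithmetic in sync_metrics is worth seeing in isolation. A standalone version, assuming the usual counter order user/nice/system/idle/iowait/irq/softirq/steal; the busy percentage is 1 - idle_delta / total_delta:

def cpu_percent(sample1: str, sample2: str) -> float:
    # Parse the first eight counters after the "cpu" label
    p1 = [int(x) for x in sample1.split()[1:9]]
    p2 = [int(x) for x in sample2.split()[1:9]]
    idle = (p2[3] - p1[3]) + (p2[4] - p1[4])  # idle + iowait
    total = sum(p2) - sum(p1)
    return round((1.0 - idle / total) * 100.0, 2) if total > 0 else 0.0

cpu_percent("cpu 100 0 50 800 10 0 5 0", "cpu 140 0 70 900 12 0 8 0")  # -> 38.18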
@ -7,16 +7,10 @@ router = APIRouter()
@router.get("/health")
async def health_check(db: AsyncSession = Depends(get_db)):
"""
Health check: verifies the API is running and the database answers SELECT 1.
"""
try:
# 尝试执行一个简单的查询来验证数据库连接
await db.execute(text("SELECT 1"))
return {"status": "ok", "database": "connected"}
except Exception as e:
# On DB failure still return 200, but report the database as disconnected
return {"status": "ok", "database": f"disconnected: {str(e)}"}

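A quick smoke test for the endpoint above; the main:app import path is an assumption about how the application module is laid out:

from fastapi.testclient import TestClient
from main import app  # assumed module path

client = TestClient(app)
resp = client.get("/health")
assert resp.status_code == 200
assert resp.json()["status"] == "ok"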
@ -12,16 +12,10 @@ router = APIRouter()
def _get_username(u) -> str:
"""从用户对象中提取用户名。"""
return getattr(u, "username", None) or (u.get("username") if isinstance(u, dict) else None)
async def _ensure_access(db: AsyncSession, username: str, cluster_uuid: str) -> int | None:
"""
Check whether the user may access the cluster (by UUID); returns the cluster's DB id, or None.
"""
uid_res = await db.execute(text("SELECT id FROM users WHERE username=:un LIMIT 1"), {"un": username})
uid_row = uid_res.first()
if not uid_row:
@ -30,7 +24,6 @@ async def _ensure_access(db: AsyncSession, username: str, cluster_uuid: str) ->
cid = cid_res.scalars().first()
if not cid:
return None
# 检查关联表是否存在记录
auth_res = await db.execute(text("SELECT 1 FROM user_cluster_mapping WHERE user_id=:uid AND cluster_id=:cid LIMIT 1"), {"uid": uid_row[0], "cid": cid})
if not auth_res.first():
return None
@ -44,12 +37,6 @@ async def start_collectors_by_cluster(
user=Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
"""
Start metrics collectors for a cluster: verify access, gather node info (id, hostname, IP), and launch async collection tasks via metrics_collector.
"""
try:
name = _get_username(user)
cid = await _ensure_access(db, name, cluster_uuid)
@ -76,12 +63,7 @@ async def get_collectors_status(
user=Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
"""
Query the metrics collectors' status (running flag, active count, errors), optionally filtered by cluster UUID.
"""
try:
name = _get_username(user)
# 即使校验失败或发生错误,也返回一个 200 结构的友好响应,而不是让接口崩掉
@ -136,12 +118,6 @@ async def stop_collectors_by_cluster(
user=Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
"""
Stop the metrics collectors for all nodes of the cluster after verifying access.
"""
try:
name = _get_username(user)
cid = await _ensure_access(db, name, cluster_uuid)
@ -163,12 +139,7 @@ async def stop_collectors_by_cluster(
@router.get("/metrics/cpu_trend")
async def cpu_trend(cluster: str = Query(...), user=Depends(get_current_user), db: AsyncSession = Depends(get_db)):
"""
Return the cluster-level CPU usage trend: a simple simulated waveform around the node-average usage.
"""
try:
name = _get_username(user)
cid = await _ensure_access(db, name, cluster)
@ -188,9 +159,7 @@ async def cpu_trend(cluster: str = Query(...), user=Depends(get_current_user), d
@router.get("/metrics/memory_usage")
async def memory_usage(cluster: str = Query(...), user=Depends(get_current_user), db: AsyncSession = Depends(get_db)):
"""
Return cluster-level memory usage as a percentage.
"""
try:
name = _get_username(user)
cid = await _ensure_access(db, name, cluster)
@ -208,9 +177,7 @@ async def memory_usage(cluster: str = Query(...), user=Depends(get_current_user)
@router.get("/metrics/cpu_trend_node")
async def cpu_trend_node(cluster: str = Query(...), node: str = Query(...), user=Depends(get_current_user), db: AsyncSession = Depends(get_db)):
"""
Return the CPU usage trend for a specific node.
"""
try:
name = _get_username(user)
cid = await _ensure_access(db, name, cluster)
@ -229,9 +196,7 @@ async def cpu_trend_node(cluster: str = Query(...), node: str = Query(...), user
@router.get("/metrics/memory_usage_node")
async def memory_usage_node(cluster: str = Query(...), node: str = Query(...), user=Depends(get_current_user), db: AsyncSession = Depends(get_db)):
"""
Return memory usage for a specific node as a percentage.
"""
try:
name = _get_username(user)
cid = await _ensure_access(db, name, cluster)
@ -246,4 +211,3 @@ async def memory_usage_node(cluster: str = Query(...), node: str = Query(...), u
raise
except Exception:
raise HTTPException(status_code=500, detail="server_error")

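The trend endpoints above return simulated series derived from a single averaged reading. One plausible shape for that generation, with the amplitude and point count as assumptions:

import math

def fake_trend(avg: float, points: int = 12) -> list[float]:
    # Oscillate around the measured average to produce a simple waveform
    return [round(max(0.0, avg + 5.0 * math.sin(i / 2.0)), 2) for i in range(points)]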
@ -13,18 +13,10 @@ router = APIRouter()
def _get_username(u) -> str:
"""
Safely extract the username from a user object (dict or model).
"""
return getattr(u, "username", None) or (u.get("username") if isinstance(u, dict) else None)
def _status_to_contract(s: str) -> str:
"""
内部状态码转换为前端展示的合同状态码
healthy -> running
unhealthy -> stopped
"""
if s == "healthy":
return "running"
if s == "unhealthy":
@ -33,18 +25,12 @@ def _status_to_contract(s: str) -> str:
def _fmt_percent(v: float | None) -> str:
"""
将浮点数值格式化为百分比字符串
"""
if v is None:
return "-"
return f"{int(round(v))}%"
def _fmt_updated(ts: datetime | None) -> str:
"""
Format the last-update time as a relative description ("刚刚", "5分钟前", ...).
"""
if not ts:
return "-"
now = now_bj()
@ -57,41 +43,26 @@ def _fmt_updated(ts: datetime | None) -> str:
class NodeDetail(BaseModel):
"""
节点详细信息的 Pydantic 模型
"""
name: str
metrics: dict
@router.get("/nodes")
async def list_nodes(cluster: str = Query(...), user=Depends(get_current_user), db: AsyncSession = Depends(get_db)):
"""
Return the node list for the given cluster, with status, CPU/memory usage, and update time formatted for display.
"""
try:
name = _get_username(user)
# 获取用户 ID
uid_res = await db.execute(text("SELECT id FROM users WHERE username=:un LIMIT 1"), {"un": name})
uid_row = uid_res.first()
if not uid_row:
return {"nodes": []}
# 获取集群 ID
cid_res = await db.execute(select(Cluster.id).where(Cluster.uuid == cluster).limit(1))
cid = cid_res.scalars().first()
if not cid:
return {"nodes": []}
# 权限校验:用户是否被授权访问该集群
auth_res = await db.execute(text("SELECT 1 FROM user_cluster_mapping WHERE user_id=:uid AND cluster_id=:cid LIMIT 1"), {"uid": uid_row[0], "cid": cid})
if not auth_res.first():
raise HTTPException(status_code=403, detail="not_allowed")
# 查询节点列表
result = await db.execute(select(Node).where(Node.cluster_id == cid).limit(500))
rows = result.scalars().all()
data = [
@ -114,30 +85,22 @@ async def list_nodes(cluster: str = Query(...), user=Depends(get_current_user),
@router.get("/nodes/{name}")
async def node_detail(name: str, user=Depends(get_current_user), db: AsyncSession = Depends(get_db)):
"""
Return details for a single node (CPU, memory, disk usage, last heartbeat) within the user's accessible clusters.
"""
try:
name_u = _get_username(user)
uid_res = await db.execute(text("SELECT id FROM users WHERE username=:un LIMIT 1"), {"un": name_u})
uid_row = uid_res.first()
if not uid_row:
raise HTTPException(status_code=404, detail="not_found")
# 仅返回用户可访问集群中的该节点
ids_res = await db.execute(text("SELECT cluster_id FROM user_cluster_mapping WHERE user_id=:uid"), {"uid": uid_row[0]})
cluster_ids = [r[0] for r in ids_res.all()]
if not cluster_ids:
raise HTTPException(status_code=404, detail="not_found")
res = await db.execute(select(Node).where(Node.hostname == name, Node.cluster_id.in_(cluster_ids)).limit(1))
n = res.scalars().first()
if not n:
raise HTTPException(status_code=404, detail="not_found")
return NodeDetail(
name=n.hostname,
metrics={
@ -155,4 +118,3 @@ async def node_detail(name: str, user=Depends(get_current_user), db: AsyncSessio
raise HTTPException(status_code=500, detail="server_error")

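_fmt_updated above promises relative descriptions such as 刚刚 and 5分钟前; a sketch of that formatting, with the exact thresholds assumed:

from datetime import datetime

def fmt_updated(ts: datetime, now: datetime) -> str:
    secs = (now - ts).total_seconds()
    if secs < 60:
        return "刚刚"
    if secs < 3600:
        return f"{int(secs // 60)}分钟前"
    return ts.strftime("%Y-%m-%d %H:%M")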
@ -22,67 +22,58 @@ router = APIRouter()
def _now() -> datetime:
"""Return the current time in Beijing (Asia/Shanghai)."""
return now_bj()
def _get_username(u) -> str:
"""
Extract the username from a Pydantic model or dict; defaults to "system".
"""
return getattr(u, "username", None) or (u.get("username") if isinstance(u, dict) else None) or "system"
def _require_ops(u):
"""
Require ops privileges (admin or ops only); raises 403 otherwise.
"""
name = _get_username(u)
if name not in {"admin", "ops"}:
raise HTTPException(status_code=403, detail="not_allowed")
async def _find_accessible_node(db: AsyncSession, user_name: str, hostname: str) -> Node | None:
"""
Find the node with the given hostname within clusters the user can access.
"""
uid_res = await db.execute(text("SELECT id FROM users WHERE username=:un LIMIT 1"), {"un": user_name})
uid_row = uid_res.first()
if not uid_row:
return None
# 查询用户关联的所有集群 ID
ids_res = await db.execute(text("SELECT cluster_id FROM user_cluster_mapping WHERE user_id=:uid"), {"uid": uid_row[0]})
cluster_ids = [r[0] for r in ids_res.all()]
if not cluster_ids:
return None
# 在这些集群中查找目标节点
res = await db.execute(select(Node).where(Node.hostname == hostname, Node.cluster_id.in_(cluster_ids)).limit(1))
return res.scalars().first()
def _gen_exec_id() -> str:
"""Generate a 32-character hex execution-record ID."""
return uuidlib.uuid4().hex[:32]
class ReadLogReq(BaseModel):
"""
读取日志请求模型
"""
node: str = Field(..., description="目标节点主机名")
path: str = Field(..., description="日志文件路径")
lines: int = Field(200, ge=1, le=5000, description="读取行数")
pattern: str | None = Field(None, description="可选过滤正则")
sshUser: str | None = Field(None, description="SSH 用户名(可选)")
timeout: int = Field(20, ge=1, le=120, description="命令超时时间(秒)")
async def _write_exec_log(db: AsyncSession, operation_id: str, description: str, user_id: int):
"""
Write a system operation audit log entry.
"""
row = SysExecLog(
user_id=user_id,
description=description,
@ -95,24 +86,15 @@ async def _write_exec_log(db: AsyncSession, operation_id: str, description: str,
@router.post("/ops/read-log")
async def read_log(req: ReadLogReq, user=Depends(get_current_user), db: AsyncSession = Depends(get_db)):
"""
Read a remote log file over SSH (tail plus optional grep), with audit logging; requires ops privileges and access to the node's cluster.
"""
try:
_require_ops(user)
uname = _get_username(user)
# Resolve the operator's user ID from the user object (falls back to 1 = admin)
user_id = getattr(user, "id", 1)
node = await _find_accessible_node(db, uname, req.node)
if not node:
raise HTTPException(status_code=404, detail="node_not_found")
# 安全转义路径和过滤模式
path_q = shlex.quote(req.path)
cmd = f"tail -n {req.lines} {path_q}"
if req.pattern:
@ -120,16 +102,13 @@ async def read_log(req: ReadLogReq, user=Depends(get_current_user), db: AsyncSes
cmd = f"{cmd} | grep -E {pat_q}"
start = _now()
# 执行远程命令
code, out, err = await run_remote_command(str(getattr(node, "ip_address", "")), req.sshUser or "", cmd, timeout=req.timeout)
desc = f"Read log: {req.path} on {req.node} (Exit: {code})"
await _write_exec_log(db, None, desc, user_id)
if code != 0:
# Non-zero exit code counts as failure (e.g. missing path or permissions)
raise HTTPException(status_code=500, detail="exec_failed")
lines = out.splitlines()
return {"exitCode": code, "lines": lines}
except HTTPException:
@ -139,9 +118,7 @@ async def read_log(req: ReadLogReq, user=Depends(get_current_user), db: AsyncSes
async def _write_hadoop_exec_log(db: AsyncSession, user_id: int, cluster_name: str, description: str, start_time: datetime, end_time: datetime):
"""
Write a Hadoop operations (start/stop) audit log entry.
"""
row = HadoopExecLog(
from_user_id=user_id,
cluster_name=cluster_name,
@ -160,14 +137,7 @@ async def start_cluster(
user=Depends(PermissionChecker(["cluster:start"])),
db: AsyncSession = Depends(get_db)
):
"""
Start the cluster: run start-dfs.sh on the NameNode and start-yarn.sh on the
ResourceManager, then update the health status and write an audit log.
"""
try:
# UUID 格式校验
try:
@ -178,13 +148,13 @@ async def start_cluster(
uname = _get_username(user)
user_id = getattr(user, "id", 1)
# 1. Look up the cluster
res = await db.execute(select(Cluster).where(Cluster.uuid == cluster_uuid).limit(1))
cluster = res.scalars().first()
if not cluster:
raise HTTPException(status_code=404, detail="cluster_not_found")
# 2. Determine the SSH user (from an associated node; defaults to "hadoop")
node_res = await db.execute(select(Node).where(Node.cluster_id == cluster.id).limit(1))
node = node_res.scalars().first()
ssh_user = node.ssh_user if node and node.ssh_user else "hadoop"
@ -192,7 +162,7 @@ async def start_cluster(
start_time = _now()
logs = []
# 3. Run start-dfs.sh on the NameNode
if cluster.namenode_ip and cluster.namenode_psw:
try:
def run_nn_start():
@ -203,7 +173,7 @@ async def start_cluster(
except Exception as e:
logs.append(f"NameNode ({cluster.namenode_ip}) start failed: {str(e)}")
# 4. Run start-yarn.sh on the ResourceManager
if cluster.rm_ip and cluster.rm_psw:
try:
def run_rm_start():
@ -216,8 +186,8 @@ async def start_cluster(
end_time = _now()
# 5. Update cluster status from the script output:
# mark error if any log line contains "failed", otherwise healthy
has_failed = any("failed" in log.lower() for log in logs)
if not has_failed:
cluster.health_status = "healthy"
@ -227,7 +197,7 @@ async def start_cluster(
cluster.updated_at = end_time
await db.flush()
# 6. Write the audit log
full_desc = " | ".join(logs)
await _write_hadoop_exec_log(db, user_id, cluster.name, f"Start Cluster: {full_desc}", start_time, end_time)
@ -245,11 +215,7 @@ async def stop_cluster(
user=Depends(PermissionChecker(["cluster:stop"])),
db: AsyncSession = Depends(get_db)
):
"""
Stop the cluster: run stop-dfs.sh on the NameNode and stop-yarn.sh on the
ResourceManager; the cluster status is then set to unknown.
"""
try:
# UUID 格式校验
try:
@ -274,7 +240,7 @@ async def stop_cluster(
start_time = _now()
logs = []
# 3. Run stop-dfs.sh on the NameNode
if cluster.namenode_ip and cluster.namenode_psw:
try:
def run_nn_stop():
@ -285,7 +251,7 @@ async def stop_cluster(
except Exception as e:
logs.append(f"NameNode ({cluster.namenode_ip}) stop failed: {str(e)}")
# 4. Run stop-yarn.sh on the ResourceManager
if cluster.rm_ip and cluster.rm_psw:
try:
def run_rm_stop():
@ -298,12 +264,12 @@ async def stop_cluster(
end_time = _now()
# 5. Set cluster status to unknown
cluster.health_status = "unknown"
cluster.updated_at = end_time
await db.flush()
# 6. Write the audit log
full_desc = " | ".join(logs)
await _write_hadoop_exec_log(db, user_id, cluster.name, f"Stop Cluster: {full_desc}", start_time, end_time)
@ -319,4 +285,3 @@ async def stop_cluster(

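The run_nn_start body is elided by the hunks above; a sketch of what such a helper can look like, reusing the SSHClient interface that appears later in this diff (the exact script invocation is an assumption):

def run_nn_start(ip: str, user: str, password: str) -> str:
    client = SSHClient(ip, user, password)
    try:
        # bash -lc loads the login profile so start-dfs.sh is on PATH
        out, err = client.execute_command_with_timeout("bash -lc 'start-dfs.sh'", 120)
        return err or out
    finally:
        client.close()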
@ -5,10 +5,6 @@ router = APIRouter()
@router.get("/user/me")
async def me(user = Depends(get_current_user)):
"""
Return the current user's profile (username, full name, active flag) from the user object injected by the auth middleware.
"""
if isinstance(user, dict):
return {"username": user.get("username"), "fullName": user.get("full_name"), "isActive": user.get("is_active")}
return {"username": user.username, "fullName": user.full_name, "isActive": user.is_active}

@ -10,9 +10,6 @@ from datetime import datetime
router = APIRouter()
class SysExecLogCreate(BaseModel):
"""
创建系统执行日志的请求模型
"""
user_id: int
description: str
@ -23,12 +20,7 @@ async def list_sys_exec_logs(
page: int = Query(1, ge=1),
size: int = Query(10, ge=1, le=100),
):
"""
List system execution logs, newest first, with page/size pagination.
"""
try:
# 构建分页查询语句
stmt = select(SysExecLog).order_by(SysExecLog.operation_time.desc()).offset((page - 1) * size).limit(size)
count_stmt = select(func.count(SysExecLog.operation_id))
@ -45,9 +37,6 @@ async def list_sys_exec_logs(
@router.post("/sys-exec-logs")
async def create_sys_exec_log(req: SysExecLogCreate, user=Depends(get_current_user), db: AsyncSession = Depends(get_db)):
"""
手动创建一条系统执行日志记录
"""
try:
row = SysExecLog(
user_id=req.user_id,
@ -62,15 +51,11 @@ async def create_sys_exec_log(req: SysExecLogCreate, user=Depends(get_current_us
@router.delete("/sys-exec-logs/{operation_id}")
async def delete_sys_exec_log(operation_id: str, user=Depends(get_current_user), db: AsyncSession = Depends(get_db)):
"""
根据操作 IDUUID删除指定的执行日志
"""
try:
# Note: operation_id is a UUID string
await db.execute(delete(SysExecLog).where(SysExecLog.operation_id == operation_id))
await db.commit()
return {"ok": True}
except Exception as e:
print(f"Error deleting sys exec log: {e}")
raise HTTPException(status_code=500, detail="server_error")

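The offset/limit pagination used by list_sys_exec_logs, shown in isolation:

page, size = 2, 10
stmt = (
    select(SysExecLog)
    .order_by(SysExecLog.operation_time.desc())
    .offset((page - 1) * size)  # skip the first (page-1)*size rows
    .limit(size)
)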
@ -12,52 +12,45 @@ from ..config import now_bj
router = APIRouter()
# 角色覆盖映射,用于演示或特殊逻辑
ROLE_OVERRIDES: dict[str, str] = {}
class CreateUserRequest(BaseModel):
"""创建用户请求模型"""
username: str
email: str
role: str
status: str
sort: int = 0
class UpdateUserRequest(BaseModel):
"""更新用户请求模型"""
role: str | None = None
status: str | None = None
sort: int | None = None
class ChangePasswordRequest(BaseModel):
"""修改密码请求模型"""
currentPassword: str
newPassword: str
def _status_to_active(status: str) -> bool:
"""前端状态字符转换为数据库布尔值"""
return status == "enabled"
def _active_to_status(active: bool) -> str:
"""数据库布尔值转换为前端状态字符"""
return "enabled" if active else "disabled"
async def _get_user_id(db: AsyncSession, username: str) -> int | None:
"""根据用户名获取用户 ID"""
res = await db.execute(text("SELECT id FROM users WHERE username=:u LIMIT 1"), {"u": username})
row = res.first()
return row[0] if row else None
async def _get_role_id(db: AsyncSession, role_key: str) -> int | None:
"""根据角色 Key 获取角色 ID"""
res = await db.execute(text("SELECT id FROM roles WHERE role_key=:k LIMIT 1"), {"k": role_key})
row = res.first()
return row[0] if row else None
async def _get_role_key(db: AsyncSession, username: str) -> str | None:
"""获取指定用户当前关联的角色 Key"""
res = await db.execute(
text(
"SELECT r.role_key FROM roles r JOIN user_role_mapping m ON r.id=m.role_id JOIN users u ON u.id=m.user_id WHERE u.username=:u LIMIT 1"
@ -68,21 +61,18 @@ async def _get_role_key(db: AsyncSession, username: str) -> str | None:
return row[0] if row else None
async def _set_user_role(db: AsyncSession, username: str, role_key: str) -> bool:
"""为指定用户设置/更新角色映射"""
uid = await _get_user_id(db, username)
if uid is None:
return False
rid = await _get_role_id(db, role_key)
if rid is None:
return False
# 清理旧角色并插入新角色
await db.execute(text("DELETE FROM user_role_mapping WHERE user_id=:uid"), {"uid": uid})
await db.execute(text("INSERT INTO user_role_mapping(user_id, role_id) VALUES(:uid, :rid)"), {"uid": uid, "rid": rid})
await db.commit()
return True
def _role_or_default(username: str) -> str:
"""获取用户角色,若无则返回默认值(主要用于演示逻辑)"""
if username in ROLE_OVERRIDES:
return ROLE_OVERRIDES[username]
if username == "admin":
@ -95,15 +85,10 @@ def _role_or_default(username: str) -> str:
def _get_username(u) -> str:
"""从不同类型的 User 对象中提取用户名"""
return getattr(u, "username", None) or (u.get("username") if isinstance(u, dict) else None)
def _require_permission(user, permission: str):
"""
权限检查辅助函数
如果当前用户不具备指定权限则抛出 403 异常
"""
perms = user.get("permissions", []) if isinstance(user, dict) else getattr(user, "permissions", [])
if permission not in perms:
raise HTTPException(status_code=403, detail=f"Permission denied: {permission}")
@ -111,10 +96,9 @@ def _require_permission(user, permission: str):
@router.get("/users")
async def list_users(user=Depends(get_current_user), db: AsyncSession = Depends(get_db)):
"""获取所有用户列表,包含其角色和状态"""
try:
_require_permission(user, "auth:manage")
result = await db.execute(select(User).order_by(User.sort.desc()).limit(500))
rows = result.scalars().all()
users = []
for u in rows:
@ -125,6 +109,7 @@ async def list_users(user=Depends(get_current_user), db: AsyncSession = Depends(
"email": u.email,
"role": rk or "observer",
"status": _active_to_status(u.is_active),
"sort": u.sort,
}
)
return {"users": users}
@ -136,11 +121,9 @@ async def list_users(user=Depends(get_current_user), db: AsyncSession = Depends(
@router.post("/users")
async def create_user(req: CreateUserRequest, user=Depends(get_current_user), db: AsyncSession = Depends(get_db)):
"""管理员手动创建新用户"""
try:
_require_permission(user, "auth:manage")
errors: list[dict] = []
# 参数校验
if not (3 <= len(req.username) <= 50) or not re.fullmatch(r"^[A-Za-z][A-Za-z0-9_]{2,49}$", req.username or ""):
errors.append({"field": "username", "message": "用户名需以字母开头,支持字母/数字/下划线,长度3-50"})
if not re.fullmatch(r"^[^@\s]+@[^@\s]+\.[^@\s]+$", req.email or ""):
@ -152,7 +135,6 @@ async def create_user(req: CreateUserRequest, user=Depends(get_current_user), db
if errors:
raise HTTPException(status_code=400, detail={"errors": errors})
# 唯一性检查
exists_username = await db.execute(select(User.id).where(User.username == req.username).limit(1))
if exists_username.scalars().first():
raise HTTPException(status_code=409, detail={"errors": [{"field": "username", "message": "用户名已存在"}]})
@ -160,7 +142,6 @@ async def create_user(req: CreateUserRequest, user=Depends(get_current_user), db
if exists_email.scalars().first():
raise HTTPException(status_code=409, detail={"errors": [{"field": "email", "message": "邮箱已存在"}]})
# 创建用户(使用临时密码)
temp_password = "TempPass#123"
password_hash = bcrypt.hash(temp_password)
now = now_bj()
@ -170,6 +151,7 @@ async def create_user(req: CreateUserRequest, user=Depends(get_current_user), db
password_hash=password_hash,
full_name=req.username,
is_active=_status_to_active(req.status),
sort=req.sort,
last_login=None,
created_at=now,
updated_at=now,
@ -177,7 +159,6 @@ async def create_user(req: CreateUserRequest, user=Depends(get_current_user), db
db.add(user_obj)
await db.flush()
await db.commit()
# 分配角色
ok = await _set_user_role(db, req.username, req.role)
if not ok:
raise HTTPException(status_code=400, detail={"errors": [{"field": "role", "message": "角色不存在"}]})
@ -190,27 +171,25 @@ async def create_user(req: CreateUserRequest, user=Depends(get_current_user), db
@router.patch("/users/{username}")
async def update_user(username: str, req: UpdateUserRequest, user=Depends(get_current_user), db: AsyncSession = Depends(get_db)):
"""更新用户信息(状态或角色)"""
try:
_require_permission(user, "auth:manage")
result = await db.execute(select(User).where(User.username == username).limit(1))
u = result.scalars().first()
if not u:
raise HTTPException(status_code=404, detail="not_found")
updates = {}
if req.status is not None:
if req.status not in {"enabled", "disabled"}:
raise HTTPException(status_code=400, detail="invalid_status")
updates["is_active"] = _status_to_active(req.status)
if req.sort is not None:
updates["sort"] = req.sort
if req.role is not None:
if req.role not in {"admin", "operator", "observer"}:
raise HTTPException(status_code=400, detail={"errors": [{"field": "role", "message": "不允许的角色"}]})
ok = await _set_user_role(db, username, req.role)
if not ok:
raise HTTPException(status_code=400, detail={"errors": [{"field": "role", "message": "角色不存在"}]})
if updates:
updates["updated_at"] = func.now()
await db.execute(update(User).where(User.id == u.id).values(**updates))
@ -224,7 +203,6 @@ async def update_user(username: str, req: UpdateUserRequest, user=Depends(get_cu
@router.delete("/users/{username}")
async def delete_user(username: str, user=Depends(get_current_user), db: AsyncSession = Depends(get_db)):
"""删除指定用户"""
try:
_require_permission(user, "auth:manage")
result = await db.execute(select(User).where(User.username == username).limit(1))
@ -243,13 +221,8 @@ async def delete_user(username: str, user=Depends(get_current_user), db: AsyncSe
@router.get("/users/with-roles")
async def list_users_with_roles(user=Depends(get_current_user), db: AsyncSession = Depends(get_db)):
"""获取所有用户及其角色信息的聚合列表"""
try:
# This endpoint requires the admin-level auth:manage permission
perms = user.get("permissions", []) if isinstance(user, dict) else getattr(user, "permissions", [])
if "auth:manage" not in perms:
raise HTTPException(status_code=403, detail="Permission denied")
res = await db.execute(
text(
"SELECT u.username,u.email,u.is_active,r.role_key FROM users u LEFT JOIN user_role_mapping m ON u.id=m.user_id LEFT JOIN roles r ON r.id=m.role_id LIMIT 500"
@ -274,27 +247,27 @@ async def list_users_with_roles(user=Depends(get_current_user), db: AsyncSession
@router.patch("/user/password")
async def change_password(req: ChangePasswordRequest, user=Depends(get_current_user), db: AsyncSession = Depends(get_db)):
"""用户自行修改密码接口"""
try:
username = _get_username(user)
# Built-in demo accounts cannot change their password
if username in {"admin", "ops", "obs"}:
raise HTTPException(status_code=400, detail="demo_user_cannot_change_password")
# Password strength check: 8-128 chars, with upper, lower, and digit
if not (8 <= len(req.newPassword) <= 128) or not re.search(r"[A-Z]", req.newPassword) or not re.search(r"[a-z]", req.newPassword) or not re.search(r"\d", req.newPassword):
raise HTTPException(status_code=400, detail="weak_new_password")
# Look up the real user record
res = await db.execute(select(User).where(User.username == username).limit(1))
u = res.scalars().first()
if not u:
raise HTTPException(status_code=401, detail="user_not_found")
# Verify the current password
if not bcrypt.verify(req.currentPassword, u.password_hash):
raise HTTPException(status_code=400, detail="invalid_current_password")
# Store the hash of the new password
new_hash = bcrypt.hash(req.newPassword)
await db.execute(update(User).where(User.id == u.id).values(password_hash=new_hash, updated_at=func.now()))
await db.commit()

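The password policy enforced in change_password, extracted as a standalone predicate:

import re

def is_strong(pw: str) -> bool:
    return (8 <= len(pw) <= 128
            and re.search(r"[A-Z]", pw) is not None
            and re.search(r"[a-z]", pw) is not None
            and re.search(r"\d", pw) is not None)

assert is_strong("Secret123") and not is_strong("secret123")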
@ -2,38 +2,38 @@ from pydantic import BaseModel
from typing import List, Dict, Optional
class LogRequest(BaseModel):
"""Log request model for fetching logs from a specific node"""
node_name: str  # e.g. hadoop102
log_type: str  # e.g. datanode, namenode
start_date: Optional[str] = None
end_date: Optional[str] = None
class SaveLogRequest(BaseModel):
"""Save log request model: persist a remote log locally"""
node_name: str
log_type: str
local_file_path: str
class LogResponse(BaseModel):
"""Log response model for a single node"""
node_name: str
log_type: str
log_content: str
class MultiLogResponse(BaseModel):
"""Multiple logs response model"""
logs: Dict[str, str]  # node name -> log content
class SaveLogResponse(BaseModel):
"""Save log response model"""
message: str
local_file_path: str
class NodeListResponse(BaseModel):
"""Node list response model"""
nodes: List[str]
class LogFilesResponse(BaseModel):
"""Log files list response model"""
node_name: str
log_files: List[str]

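Example construction of the request model above (values hypothetical; model_dump is Pydantic v2, use .dict() on v1):

req = LogRequest(node_name="hadoop102", log_type="datanode")
req.model_dump()
# {'node_name': 'hadoop102', 'log_type': 'datanode', 'start_date': None, 'end_date': None}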
@ -5,20 +5,9 @@ from ..ssh_utils import SSHClient
def collect_cluster_uuid(host: str, user: str, password: str, timeout: int | None = None) -> tuple[str | None, str | None, str | None]:
"""
通过 SSH 探测并获取远程 Hadoop 集群的 UUID
1. 建立 SSH 连接
2. 执行命令获取 NameNode 的元数据存储目录dfs.namenode.name.dir
3. 读取该目录下 current/VERSION 文件
4. 解析并提取 clusterID 字段去掉 'CID-' 前缀
返回 (cluster_id, 错误阶段标识, 详细错误信息)
"""
cli = None
try:
# 建立 SSH 客户端连接
cli = SSHClient(str(host), user or "", password or "")
# 第一阶段:探测 NameNode 目录配置
out, err = cli.execute_command_with_timeout(
"hdfs getconf -confKey dfs.namenode.name.dir",
timeout or SSH_TIMEOUT,
@ -26,15 +15,11 @@ def collect_cluster_uuid(host: str, user: str, password: str, timeout: int | Non
if not out or not out.strip():
return None, "probe_name_dirs", (err or "empty_output")
# 处理多目录情况,取第一个
name_dir = out.strip().split(",")[0]
if name_dir.startswith("file://"):
name_dir = name_dir[7:]
# 拼接 VERSION 文件路径
version_path = f"{name_dir.rstrip('/')}/current/VERSION"
# 第二阶段:读取 VERSION 文件内容
version_out, version_err = cli.execute_command_with_timeout(
f"cat {version_path}",
timeout or SSH_TIMEOUT,
@ -42,7 +27,6 @@ def collect_cluster_uuid(host: str, user: str, password: str, timeout: int | Non
if not version_out or not version_out.strip():
return None, "read_version", (version_err or "empty_output")
# 第三阶段:解析 clusterID
cluster_id = None
for line in version_out.splitlines():
if "clusterID" in line:
@ -53,15 +37,12 @@ def collect_cluster_uuid(host: str, user: str, password: str, timeout: int | Non
if not cluster_id:
return None, "parse_cluster_id", version_out.strip()
# 去除 CID- 前缀,返回标准 UUID
if cluster_id.startswith("CID-"):
cluster_id = cluster_id[4:]
return cluster_id, None, None
except Exception as e:
# 捕获连接或执行过程中的异常
return None, "connect_or_exec", str(e)
finally:
# 确保 SSH 连接关闭
try:
if cli:
cli.close()

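The clusterID extraction above, run against a sample VERSION file; the file content and the split on '=' are illustrative:

version_out = (
    "namespaceID=1428429254\n"
    "clusterID=CID-12345678-aaaa-bbbb-cccc-1234567890ab\n"
    "storageType=NAME_NODE\n"
)
cluster_id = None
for line in version_out.splitlines():
    if "clusterID" in line:
        cluster_id = line.split("=", 1)[1].strip()
        break
if cluster_id and cluster_id.startswith("CID-"):
    cluster_id = cluster_id[4:]  # "12345678-aaaa-bbbb-cccc-1234567890ab"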
@ -13,10 +13,6 @@ load_dotenv()
_shared_async_client: Any = None
def _get_async_client() -> Any:
"""
Return a singleton async HTTP client (httpx) with a connection pool (20 keep-alive, 50 total connections) and HTTP/2 enabled.
"""
global _shared_async_client
if httpx is None:
return None
@ -42,7 +38,6 @@ _DEFAULT_MODELS: Dict[str, str] = {
}
def _clean_str(s: str) -> str:
"""清理字符串:去除两端空白,并移除两端可能存在的引号(单引号、双引号、反引号)。"""
if s is None:
return ""
s = s.strip()
@ -51,12 +46,6 @@ def _clean_str(s: str) -> str:
return s
def _normalize_endpoint(ep: str) -> str:
"""
标准化 API 终点 URL
1. 移除末尾斜杠
2. 如果以 /v1 结尾补全为 /chat/completions
3. 如果没有 /chat/completions则添加
"""
if not ep:
return ep
s = _clean_str(ep).rstrip("/")
@ -68,41 +57,23 @@ def _normalize_endpoint(ep: str) -> str:
class LLMClient:
"""
LLM client supporting multiple providers (OpenAI, SiliconFlow, DeepSeek),
streaming responses, and tool calling; configured via environment variables.
"""
def __init__(self):
# 确定供应商
self.provider = os.getenv("LLM_PROVIDER", "openai").strip().lower()
# 确定 API 终点
raw_endpoint = os.getenv("LLM_ENDPOINT", "") or _DEFAULT_ENDPOINTS.get(self.provider, _DEFAULT_ENDPOINTS["openai"])
self.endpoint = _normalize_endpoint(raw_endpoint)
# 确定模型名称
self.model = _clean_str(os.getenv("LLM_MODEL", _DEFAULT_MODELS.get(self.provider, "gpt-4o-mini")))
# 获取 API 密钥(按优先级从多个环境变量尝试)
api_key = os.getenv("LLM_API_KEY") or os.getenv("OPENAI_API_KEY") or os.getenv("DEEPSEEK_API_KEY") or os.getenv("SILICONFLOW_API_KEY") or ""
self.api_key = api_key
# 是否启用模拟模式(不实际调用接口,返回假数据)
self.simulate = os.getenv("LLM_SIMULATE", "false").lower() == "true"
# 接口请求超时时间
self.timeout = int(os.getenv("LLM_TIMEOUT", "300"))
def _headers(self) -> Dict[str, str]:
"""构造 API 请求头。"""
return {
"Authorization": f"Bearer {self.api_key}" if self.api_key else "",
"Content-Type": "application/json",
}
async def chat(self, messages: List[Dict[str, Any]], tools: Optional[List[Dict[str, Any]]] = None, stream: bool = False, model: Optional[str] = None) -> Any:
"""
发起对话请求
1. 检查模拟模式
2. 构造请求 Payload模型消息流式开关工具配置
3. 如果是流式响应返回一个异步生成器否则返回完整的 JSON 响应
"""
if self.simulate or httpx is None:
if stream:
async def _sim_stream():
@ -127,7 +98,6 @@ class LLMClient:
payload["tool_choice"] = "auto"
if stream:
# 处理流式响应的生成器逻辑
async def _stream_gen():
client = _get_async_client()
async with client.stream("POST", self.endpoint, headers=self._headers(), json=payload, timeout=self.timeout) as resp:
@ -144,7 +114,6 @@ class LLMClient:
continue
return _stream_gen()
# 发起非流式 POST 请求
client = _get_async_client()
resp = await client.post(self.endpoint, headers=self._headers(), json=payload, timeout=self.timeout)
resp.raise_for_status()

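Per the _normalize_endpoint docstring above, these are the expected rewrites (illustrative expectations, not tests from the repo):

assert _normalize_endpoint("https://api.openai.com/v1/") == "https://api.openai.com/v1/chat/completions"
assert _normalize_endpoint("https://api.deepseek.com/v1/chat/completions") == "https://api.deepseek.com/v1/chat/completions"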
@ -23,18 +23,12 @@ from ..config import now_bj
def _now() -> datetime:
"""Return the current time in Beijing (Asia/Shanghai)."""
return now_bj()
async def _find_accessible_node(db: AsyncSession, user_name: str, hostname: str) -> Optional[Node]:
"""
Find a node by hostname within the clusters the user can access.
"""
uid_res = await db.execute(text("SELECT id FROM users WHERE username=:un LIMIT 1"), {"un": user_name})
uid_row = uid_res.first()
if not uid_row:
@ -48,7 +42,6 @@ async def _find_accessible_node(db: AsyncSession, user_name: str, hostname: str)
async def _user_has_cluster_access(db: AsyncSession, user_name: str, cluster_id: int) -> bool:
"""校验用户是否有权访问指定的集群。"""
uid_res = await db.execute(text("SELECT id FROM users WHERE username=:un LIMIT 1"), {"un": user_name})
uid_row = uid_res.first()
if not uid_row:
@ -61,13 +54,7 @@ async def _user_has_cluster_access(db: AsyncSession, user_name: str, cluster_id:
async def _write_exec_log(db: AsyncSession, exec_id: str, command_type: str, status: str, start: datetime, end: Optional[datetime], exit_code: Optional[int], operator: str, stdout: Optional[str] = None, stderr: Optional[str] = None):
"""
Write an execution audit log entry, resolving the operator's user ID and first mapped cluster.
"""
# 查找 from_user_id 和 cluster_name
uid_res = await db.execute(text("SELECT id FROM users WHERE username=:un LIMIT 1"), {"un": operator})
uid_row = uid_res.first()
@ -96,34 +83,21 @@ async def _write_exec_log(db: AsyncSession, exec_id: str, command_type: str, sta
async def tool_read_log(db: AsyncSession, user_name: str, node: str, path: str, lines: int = 200, pattern: Optional[str] = None, ssh_user: Optional[str] = None, timeout: int = 20) -> Dict[str, Any]:
"""
Tool: read a remote node's log file (tail N lines, optional regex filter).
"""
n = await _find_accessible_node(db, user_name, node)
if not n:
return {"error": "node_not_found"}
if not getattr(n, "ssh_password", None):
return {"error": "ssh_password_not_configured"}
# Quote the path and build the tail command
path_q = shlex.quote(path)
cmd = f"tail -n {lines} {path_q}"
if pattern:
pat_q = shlex.quote(pattern)
cmd = f"{cmd} | grep -E {pat_q}"
start = _now()
# Use bash -lc so the login environment is loaded
bash_cmd = f"bash -lc {shlex.quote(cmd)}"
def _run():
"""在线程池中执行阻塞的 SSH 调用。"""
client = ssh_manager.get_connection(
str(getattr(n, "hostname", node)),
ip=str(getattr(n, "ip_address", "")),
@ -132,26 +106,17 @@ async def tool_read_log(db: AsyncSession, user_name: str, node: str, path: str,
)
return client.execute_command_with_timeout_and_status(bash_cmd, timeout=timeout)
# Run the blocking SSH command off the event loop
code, out, err = await asyncio.to_thread(_run)
end = _now()
exec_id = f"tool_{start.timestamp():.0f}"
# Write the audit log entry
await _write_exec_log(db, exec_id, "read_log", ("success" if code == 0 else "failed"), start, end, code, user_name, out, err)
return {"execId": exec_id, "exitCode": code, "stdout": out, "stderr": err}
async def _fetch_page_text(client: httpx.AsyncClient, url: str) -> str:
"""
抓取并提取网页正文文本
1. 过滤非 HTTP 链接
2. 发起异步请求并设置 User-Agent
3. 使用 BeautifulSoup 解析 HTML移除无关标签脚本样式导航等
4. 返回清洗后的前 2000 个字符
"""
"""Fetch and extract text content from a URL."""
try:
# Skip if not a valid http url
if not url.startswith("http"):
return ""
@ -161,11 +126,11 @@ async def _fetch_page_text(client: httpx.AsyncClient, url: str) -> str:
resp = await client.get(url, headers=headers, follow_redirects=True)
if resp.status_code == 200:
soup = BeautifulSoup(resp.text, "html.parser")
# Strip noisy elements
# Remove scripts and styles
for script in soup(["script", "style", "nav", "footer", "header"]):
script.decompose()
text = soup.get_text(separator="\n", strip=True)
# Limit text length to avoid context overflow
# Limit text length
return text[:2000]
except Exception:
pass
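# Illustrative usage sketch for the helper above; the URL is a placeholder.
# async with httpx.AsyncClient(timeout=10) as client:
#     text = await _fetch_page_text(client, "https://example.com/article")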
@ -173,13 +138,7 @@ async def _fetch_page_text(client: httpx.AsyncClient, url: str) -> str:
async def tool_web_search(query: str, max_results: int = 5) -> Dict[str, Any]:
"""
核心工具通过百度进行联网搜索
1. 构造搜索请求
2. 解析搜索结果页面提取标题链接和摘要
3. 并发抓取排名前 2 的网页正文内容
"""
"""工具联网搜索Baidu并读取网页内容。"""
try:
results = []
headers = {
@ -192,12 +151,12 @@ async def tool_web_search(query: str, max_results: int = 5) -> Dict[str, Any]:
url = "https://www.baidu.com/s"
params = {"wd": query}
# Use synchronous requests for the search page (usually more stable)
# Use sync requests for search page (stable)
resp = requests.get(url, params=params, headers=headers, timeout=10, verify=False)
if resp.status_code == 200:
soup = BeautifulSoup(resp.text, "html.parser")
# Extract Baidu search result items
# Baidu results are usually in div with class c-container
for item in soup.select("div.c-container, div.result.c-container")[:max_results]:
title_elem = item.select_one("h3")
if not title_elem:
@ -206,20 +165,21 @@ async def tool_web_search(query: str, max_results: int = 5) -> Dict[str, Any]:
link_elem = item.select_one("a")
href = link_elem.get("href") if link_elem else ""
# Extract the snippet and strip the duplicated title
# Abstract/Snippet
snippet = item.get_text(strip=True).replace(title, "")[:200]
results.append({
"title": title,
"href": href,
"body": snippet,
"full_content": "" # 待填充正文
"full_content": "" # Placeholder
})
# Concurrently fetch the full text of the top 2 results
# Fetch full content for top 2 results
if results:
async with httpx.AsyncClient(timeout=10, verify=False) as client:
tasks = []
# Only fetch top 2 to avoid long wait
for r in results[:2]:
tasks.append(_fetch_page_text(client, r["href"]))
@ -228,9 +188,10 @@ async def tool_web_search(query: str, max_results: int = 5) -> Dict[str, Any]:
for i, content in enumerate(contents):
if content:
results[i]["full_content"] = content
# Append note to body to indicate full content is available
results[i]["body"] += "\n[Full content fetched]"
# Add the current system time to help the LLM with "now"-style queries
# Add current system time to help with "now" queries
current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S %A")
return {"query": query, "current_time": current_time, "results": results}
except Exception as e:
@ -238,15 +199,7 @@ async def tool_web_search(query: str, max_results: int = 5) -> Dict[str, Any]:
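# Illustrative sketch of the top-k concurrent fetch step in tool_web_search,
# assuming _fetch_page_text as defined earlier; fetch_top is a hypothetical name.
import asyncio
import httpx

async def fetch_top(results: list, k: int = 2) -> None:
    async with httpx.AsyncClient(timeout=10, verify=False) as client:
        tasks = [_fetch_page_text(client, r["href"]) for r in results[:k]]
        contents = await asyncio.gather(*tasks)  # fetch pages concurrently
        for r, content in zip(results, contents):
            if content:
                r["full_content"] = content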
async def tool_start_cluster(db: AsyncSession, user_name: str, cluster_uuid: str) -> Dict[str, Any]:
"""
核心工具启动指定的 Hadoop 集群
1. 校验集群存在性
2. 获取 SSH 登录凭据
3. NameNode 执行 start-dfs.sh
4. ResourceManager 执行 start-yarn.sh
5. 更新数据库中的集群健康状态
"""
"""工具:启动 Hadoop 集群。"""
# 1. Permissions and user
uid_res = await db.execute(text("SELECT id FROM users WHERE username=:un LIMIT 1"), {"un": user_name})
uid_row = uid_res.first()
@ -266,7 +219,7 @@ async def tool_start_cluster(db: AsyncSession, user_name: str, cluster_uuid: str
start_time = _now()
logs = []
# 4. Run the start script on the NameNode
# 4. Run start-dfs.sh on the NameNode
if cluster.namenode_ip and cluster.namenode_psw:
try:
def run_nn_start():
@ -277,7 +230,7 @@ async def tool_start_cluster(db: AsyncSession, user_name: str, cluster_uuid: str
except Exception as e:
logs.append(f"NameNode ({cluster.namenode_ip}) start failed: {str(e)}")
# 5. Run the start script on the ResourceManager
# 5. Run start-yarn.sh on the ResourceManager
if cluster.rm_ip and cluster.rm_psw:
try:
def run_rm_start():
@ -290,7 +243,7 @@ async def tool_start_cluster(db: AsyncSession, user_name: str, cluster_uuid: str
end_time = _now()
# 6. Update cluster status
# 6. Update cluster status (improved: check for failure logs)
has_failed = any("failed" in log.lower() for log in logs)
if not has_failed:
cluster.health_status = "healthy"
@ -300,7 +253,7 @@ async def tool_start_cluster(db: AsyncSession, user_name: str, cluster_uuid: str
cluster.updated_at = end_time
await db.flush()
# 7. Write the audit log
# 7. Write the log
full_desc = " | ".join(logs)
exec_row = HadoopExecLog(
from_user_id=user_id,
@ -316,7 +269,7 @@ async def tool_start_cluster(db: AsyncSession, user_name: str, cluster_uuid: str
async def tool_stop_cluster(db: AsyncSession, user_name: str, cluster_uuid: str) -> Dict[str, Any]:
"""核心工具:停止指定的 Hadoop 集群(反向操作启动逻辑)"""
"""工具:停止 Hadoop 集群。"""
uid_res = await db.execute(text("SELECT id FROM users WHERE username=:un LIMIT 1"), {"un": user_name})
uid_row = uid_res.first()
user_id = uid_row[0] if uid_row else 1
@ -333,7 +286,6 @@ async def tool_stop_cluster(db: AsyncSession, user_name: str, cluster_uuid: str)
start_time = _now()
logs = []
# Stop HDFS
if cluster.namenode_ip and cluster.namenode_psw:
try:
def run_nn_stop():
@ -344,7 +296,6 @@ async def tool_stop_cluster(db: AsyncSession, user_name: str, cluster_uuid: str)
except Exception as e:
logs.append(f"NameNode ({cluster.namenode_ip}) stop failed: {str(e)}")
# Stop YARN
if cluster.rm_ip and cluster.rm_psw:
try:
def run_rm_stop():
@ -383,13 +334,7 @@ async def tool_read_cluster_log(
node_hostname: Optional[str] = None,
lines: int = 100,
) -> Dict[str, Any]:
"""
核心工具智能读取集群中特定组件的日志
1. 根据集群 UUID 查找集群
2. 确定目标 IPNameNodeResourceManager 或指定的主机
3. 通过 SSH 尝试在标准日志目录下寻找并 tail 日志文件
"""
"""读取集群中特定服务类型的日志。"""
import uuid as uuidlib
try:
uuidlib.UUID(cluster_uuid)
@ -409,7 +354,6 @@ async def tool_read_cluster_log(
ssh_user: Optional[str] = None
ssh_password: Optional[str] = None
# Pick the target node automatically from log_type
if log_type.lower() == "namenode":
target_ip = str(cluster.namenode_ip) if cluster.namenode_ip else None
ssh_password = cluster.namenode_psw
@ -431,7 +375,6 @@ async def tool_read_cluster_log(
if node_obj and node_obj.ssh_user:
ssh_user = node_obj.ssh_user
# A node was specified by hostname
if not target_ip and target_hostname:
node = await _find_accessible_node(db, user_name, target_hostname)
if not node:
@ -446,14 +389,10 @@ async def tool_read_cluster_log(
target_hostname = target_ip
def _tail_via_ssh() -> Dict[str, Any]:
"""SSH 远程读取日志逻辑。"""
ip = str(target_ip)
hn = str(target_hostname)
# Initialize log-directory discovery
log_reader.find_working_log_dir(hn, ip)
ssh_client = ssh_manager.get_connection(hn, ip=ip, username=ssh_user, password=ssh_password)
# Try the known log file paths first
paths = log_reader.get_log_file_paths(hn, log_type.lower())
for p in paths:
p_q = shlex.quote(p)
@ -464,8 +403,6 @@ async def tool_read_cluster_log(
if err2:
continue
return {"status": "success", "node": hn, "log_type": log_type, "path": p, "content": out2}
# Fallback: list the log directory and fuzzy-match the component name
base_dir = log_reader._node_log_dir.get(hn, log_reader.log_dir)
base_q = shlex.quote(base_dir)
out, err = ssh_client.execute_command(f"ls -1 {base_q} 2>/dev/null")
@ -476,7 +413,6 @@ async def tool_read_cluster_log(
lf = f.lower()
if not f:
continue
# Match files containing the component name and ending with a common log suffix
if log_type.lower() in lf and hn.lower() in lf and (lf.endswith(".log") or lf.endswith(".out") or lf.endswith(".out.1")):
full = f"{base_dir}/{f}"
full_q = shlex.quote(full)
@ -488,7 +424,6 @@ async def tool_read_cluster_log(
return await asyncio.to_thread(_tail_via_ssh)
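# Illustrative pattern (hypothetical names): paramiko calls block, so the tool
# pushes them onto a worker thread with asyncio.to_thread, keeping the event loop free.
import asyncio

def blocking_ssh_read() -> str:
    ...  # e.g. a paramiko exec_command call
    return ""

async def non_blocking_caller() -> str:
    return await asyncio.to_thread(blocking_ssh_read)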
# Predefined Hadoop fault rule library
_FAULT_RULES: List[Dict[str, Any]] = [
{
"id": "hdfs_safemode",
@ -557,7 +492,6 @@ _FAULT_RULES: List[Dict[str, Any]] = [
def _detect_faults_from_log_text(text: str, max_examples_per_rule: int = 3) -> List[Dict[str, Any]]:
"""根据规则库从日志文本中识别故障。"""
lines = (text or "").splitlines()
hits: List[Dict[str, Any]] = []
for rule in _FAULT_RULES:
@ -567,7 +501,6 @@ def _detect_faults_from_log_text(text: str, max_examples_per_rule: int = 3) -> L
for idx, line in enumerate(lines):
if not line:
continue
# A hit: the line matches any of the rule's regexes
if any(rgx.search(line) for rgx in compiled):
examples.append({"lineNo": idx + 1, "line": line[:500]})
if len(examples) >= max_examples_per_rule:
@ -594,13 +527,6 @@ async def tool_detect_cluster_faults(
node_hostname: Optional[str] = None,
lines: int = 200,
) -> Dict[str, Any]:
"""
核心工具自动检测集群组件的常见故障
1. 依次读取指定组件默认 NameNode, ResourceManager的日志
2. 使用规则库匹配日志内容
3. 返回识别到的故障列表按严重程度排序
"""
import uuid as uuidlib
try:
@ -618,7 +544,6 @@ async def tool_detect_cluster_faults(
faults: List[Dict[str, Any]] = []
for comp in comps:
# Reuse read_cluster_log to fetch the content
r = await tool_read_cluster_log(
db=db,
user_name=user_name,
@ -631,7 +556,6 @@ async def tool_detect_cluster_faults(
if r.get("status") != "success":
continue
content = r.get("content") or ""
# Run rule-based detection
comp_faults = _detect_faults_from_log_text(content)
for f in comp_faults:
f2 = dict(f)
@ -640,7 +564,6 @@ async def tool_detect_cluster_faults(
f2["path"] = r.get("path")
faults.append(f2)
# Sort by severity (high > medium > low)
severity_order = {"high": 0, "medium": 1, "low": 2}
faults.sort(key=lambda x: (severity_order.get((x.get("severity") or "").lower(), 9), x.get("id") or ""))
@ -653,7 +576,6 @@ async def tool_detect_cluster_faults(
}
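# Illustrative sketch of the severity ordering above, on toy data (the ids other
# than hdfs_safemode are made up for the example).
severity_order = {"high": 0, "medium": 1, "low": 2}
sample = [
    {"id": "yarn_oom", "severity": "medium"},
    {"id": "hdfs_safemode", "severity": "high"},
    {"id": "misc", "severity": None},
]
sample.sort(key=lambda x: (severity_order.get((x.get("severity") or "").lower(), 9), x.get("id") or ""))
# -> hdfs_safemode (high), then yarn_oom (medium), then misc (unknown severity sorts last)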
# Predefined whitelist of ops commands and their default targets
_OPS_COMMANDS: Dict[str, Dict[str, Any]] = {
"jps": {"cmd": "jps -lm", "target": "all_nodes"},
"hadoop_version": {"cmd": "hadoop version", "target": "namenode"},
@ -678,14 +600,6 @@ async def tool_run_cluster_command(
timeout: int = 30,
limit_nodes: int = 20,
) -> Dict[str, Any]:
"""
核心工具在集群节点上安全执行白名单内的运维命令
参数:
- command_key: 命令标识 jps, df_h
- target: 目标namenode, resourcemanager, node, all_nodes
- node_hostname: target=node 时指定的主机名
"""
import uuid as uuidlib
try:
@ -713,7 +627,6 @@ async def tool_run_cluster_command(
bash_cmd = f"bash -lc {shlex.quote(cmd)}"
async def _exec_on_node(hostname: str, ip: str, ssh_user: Optional[str], ssh_password: Optional[str]) -> Dict[str, Any]:
"""执行 SSH 命令的闭包。"""
def _run():
client = ssh_manager.get_connection(hostname, ip=ip, username=ssh_user, password=ssh_password)
exit_code, out, err = client.execute_command_with_timeout_and_status(bash_cmd, timeout=timeout)
@ -730,7 +643,6 @@ async def tool_run_cluster_command(
results: List[Dict[str, Any]] = []
# Dispatch by target
if tgt == "namenode":
if not cluster.namenode_ip or not cluster.namenode_psw:
return {"status": "error", "message": "namenode_not_configured"}
@ -760,7 +672,6 @@ async def tool_run_cluster_command(
results.append(await _exec_on_node(node.hostname, str(node.ip_address), node.ssh_user or "hadoop", node.ssh_password))
elif tgt == "all_nodes":
# Run on all nodes in batch (with a limit)
nodes_stmt = select(Node).where(Node.cluster_id == cluster.id).limit(limit_nodes)
nodes = (await db.execute(nodes_stmt)).scalars().all()
for n in nodes:
@ -772,7 +683,6 @@ async def tool_run_cluster_command(
else:
return {"status": "error", "message": "invalid_target"}
# Write the execution log
start = _now()
exec_id = f"tool_{start.timestamp():.0f}"
await _write_exec_log(db, exec_id, "run_cluster_command", "success", start, _now(), 0, user_name)
@ -788,7 +698,7 @@ async def tool_run_cluster_command(
def openai_tools_schema() -> List[Dict[str, Any]]:
"""返回 OpenAI 兼容的工具定义Function Calling,供 LLM 调用"""
"""返回 OpenAI 兼容的工具定义Function Calling"""
return [
{
"type": "function",

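# Illustrative sketch: the shape of one entry returned by openai_tools_schema(),
# following the OpenAI function-calling format. The parameter schema below is an
# assumption inferred from tool_read_log's signature, not copied from the source.
read_log_tool = {
    "type": "function",
    "function": {
        "name": "read_log",
        "description": "Tail a log file on a remote node, optionally filtered by regex",
        "parameters": {
            "type": "object",
            "properties": {
                "node": {"type": "string"},
                "path": {"type": "string"},
                "lines": {"type": "integer", "default": 200},
                "pattern": {"type": "string"},
            },
            "required": ["node", "path"],
        },
    },
}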
@ -5,17 +5,10 @@ from typing import Optional, Tuple
async def run_local_command(cmd: str, timeout: int = 30) -> Tuple[int, str, str]:
"""
异步运行本地 Shell 命令
1. 根据操作系统类型Windows/Linux选择合适的执行程序powershell/bash
2. 创建子进程并捕获标准输出和标准错误
3. 支持超时机制超时后会自动杀死子进程并返回 124 退出码
返回 (退出码, 标准输出, 标准错误)
"""
"""运行本地命令,返回 (exit_code, stdout, stderr)。"""
if os.name == "nt":
prog = ["powershell", "-NoProfile", "-NonInteractive", "-Command", cmd]
else:
# Use bash -lc so the user's login environment is loaded
prog = ["bash", "-lc", cmd]
proc = await asyncio.create_subprocess_exec(
*prog,
@ -34,11 +27,7 @@ async def run_local_command(cmd: str, timeout: int = 30) -> Tuple[int, str, str]
def _build_ssh_prog(host: str, user: str, cmd: str, port: Optional[int] = None, identity_file: Optional[str] = None) -> list:
"""
内部辅助函数构造 SSH 远程执行命令的参数列表
- 启用 BatchMode 以禁用交互式提示
- 禁用 StrictHostKeyChecking 以自动接受新的主机密钥
"""
"""构造 ssh 远程执行命令参数数组。"""
prog = [
"ssh",
"-o",
@ -51,19 +40,12 @@ def _build_ssh_prog(host: str, user: str, cmd: str, port: Optional[int] = None,
if identity_file:
prog += ["-i", identity_file]
target = f"{user}@{host}" if user else host
# Execute the command via bash -lc on the remote host
prog += [target, "bash", "-lc", cmd]
return prog
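# Illustrative sketch: an argv this helper would produce, assuming the BatchMode /
# StrictHostKeyChecking options described in the docstring.
# _build_ssh_prog("hadoop102", "hadoop", "jps")
# -> ["ssh", "-o", "BatchMode=yes", "-o", "StrictHostKeyChecking=no",
#     "hadoop@hadoop102", "bash", "-lc", "jps"]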
async def run_remote_command(host: str, user: str, cmd: str, timeout: int = 30, port: Optional[int] = None, identity_file: Optional[str] = None) -> Tuple[int, str, str]:
"""
异步运行远程 SSH 命令
1. 调用 _build_ssh_prog 构造命令参数
2. 使用 asyncio.create_subprocess_exec 启动 SSH 进程
3. 支持超时处理
返回 (退出码, 标准输出, 标准错误)
"""
"""通过 ssh 在远端主机执行命令,返回 (exit_code, stdout, stderr)。"""
prog = _build_ssh_prog(host, user, cmd, port=port, identity_file=identity_file)
proc = await asyncio.create_subprocess_exec(
*prog,

@ -2,10 +2,6 @@ from ..ssh_utils import SSHClient
from ..config import SSH_TIMEOUT
def check_ssh_connectivity(host: str, user: str, password: str, timeout: int | None = None) -> tuple[bool, str | None]:
"""
检查指定主机的 SSH 连通性
尝试执行 'echo ok' 命令如果成功返回且有输出则认为连接正常
"""
try:
cli = SSHClient(str(host), user or "", password or "")
out, _ = cli.execute_command_with_timeout("echo ok", timeout or SSH_TIMEOUT)
@ -24,38 +20,37 @@ def check_ssh_connectivity(host: str, user: str, password: str, timeout: int | N
def get_hdfs_cluster_id(host: str, user: str, password: str, timeout: int | None = None) -> tuple[str | None, str | None]:
"""
Fetch the HDFS cluster UUID over SSH:
1. Run 'hdfs getconf -confKey dfs.namenode.name.dir' to get the NameNode storage directory.
2. Look for clusterID in that directory's current/VERSION file.
3. Parse clusterID and strip the 'CID-' prefix.
Returns (cluster_id, error_message)
Fetch the HDFS cluster UUID via these steps:
1. Run hdfs getconf -confKey dfs.namenode.name.dir to get the NameNode directory.
2. Read the VERSION file under that directory's current subdirectory.
3. Parse the clusterID field from the VERSION file.
4. Strip the 'CID-' prefix and return.
"""
try:
cli = SSHClient(str(host), user or "", password or "")
# 1. Get the dfs.namenode.name.dir config value
# 1. Get dfs.namenode.name.dir
dir_out, dir_err = cli.execute_command_with_timeout("hdfs getconf -confKey dfs.namenode.name.dir", timeout or SSH_TIMEOUT)
if not dir_out or not dir_out.strip():
cli.close()
return None, f"Failed to get dfs.namenode.name.dir: {dir_err or 'Empty output'}"
# Handle possible multiple directories (comma-separated; take the first)
# Handle possible multiple directories (take the first)
name_dir = dir_out.strip().split(',')[0]
# Strip the file:// prefix if present
if name_dir.startswith("file://"):
name_dir = name_dir[7:]
# Build the VERSION file path
version_path = f"{name_dir.rstrip('/')}/current/VERSION"
# 2. Read the VERSION file contents
# 2. Read the VERSION file
version_out, version_err = cli.execute_command_with_timeout(f"cat {version_path}", timeout or SSH_TIMEOUT)
cli.close()
if not version_out or not version_out.strip():
return None, f"Failed to read VERSION file at {version_path}: {version_err or 'Empty output'}"
# 3. Parse the clusterID field from the file contents
# 3. Parse clusterID
cluster_id = None
for line in version_out.splitlines():
if line.startswith("clusterID="):
@ -65,7 +60,7 @@ def get_hdfs_cluster_id(host: str, user: str, password: str, timeout: int | None
if not cluster_id:
return None, f"clusterID not found in {version_path}"
# 4. Strip the 'CID-' prefix, returning a bare UUID
# 4. Strip the 'CID-' prefix
if cluster_id.startswith("CID-"):
cluster_id = cluster_id[4:]
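# Illustrative sketch: the VERSION parsing above as a pure function, easy to
# unit-test. parse_cluster_id is a hypothetical name.
from typing import Optional

def parse_cluster_id(version_text: str) -> Optional[str]:
    for line in version_text.splitlines():
        if line.startswith("clusterID="):
            cid = line.split("=", 1)[1].strip()
            return cid[4:] if cid.startswith("CID-") else cid  # strip 'CID-'
    return None

# parse_cluster_id("namespaceID=42\nclusterID=CID-1234-abcd\n") -> "1234-abcd"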

@ -4,8 +4,8 @@ import paramiko
from typing import Optional, TextIO, Dict, Tuple
from .config import SSH_PORT, SSH_TIMEOUT
# Static node configuration dictionary used for all requests.
# Avoids environment variables being unavailable in child processes.
# Create a static node configuration dictionary that will be used for all requests
# This avoids the issue of environment variables not being available in child processes
STATIC_NODE_CONFIG = {
"hadoop102": ("192.168.10.102", "hadoop", "limouren..."),
"hadoop103": ("192.168.10.103", "hadoop", "limouren..."),
@ -14,14 +14,11 @@ STATIC_NODE_CONFIG = {
"hadoop100": ("192.168.10.100", "hadoop", "limouren...")
}
# Default SSH username and password, read from environment variables
DEFAULT_SSH_USER = os.getenv("HADOOP_USER", "hadoop")
DEFAULT_SSH_PASSWORD = os.getenv("HADOOP_PASSWORD", "limouren...")
class SSHClient:
"""
SSH client wrapping paramiko; connects to remote servers and runs operations.
"""
"""SSH Client for connecting to remote servers"""
def __init__(self, hostname: str, username: str, password: str, port: int = SSH_PORT):
self.hostname = hostname
@ -31,7 +28,6 @@ class SSHClient:
self.client: Optional[paramiko.SSHClient] = None
def _ensure_connected(self) -> None:
"""确保 SSH 连接处于激活状态,如果未连接或连接断开则重新连接。"""
if self.client is None:
self.connect()
return
@ -43,12 +39,11 @@ class SSHClient:
self.connect()
def connect(self) -> None:
"""建立 SSH 连接"""
"""Establish SSH connection"""
self.client = paramiko.SSHClient()
# Auto-add remote host keys (convenient but insecure; use with care in production)
# Automatically add host keys
self.client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
sock = None
# If a SOCKS5 proxy is configured (e.g. Tailscale), connect through it
socks5 = os.getenv("TS_SOCKS5_SERVER") or os.getenv("TAILSCALE_SOCKS5_SERVER")
if socks5:
try:
@ -65,35 +60,33 @@ class SSHClient:
)
def execute_command(self, command: str) -> tuple:
"""在远程服务器上执行命令,返回 (标准输出, 标准错误)"""
"""Execute command on remote server"""
self._ensure_connected()
stdin, stdout, stderr = self.client.exec_command(command)
return stdout.read().decode(), stderr.read().decode()
def execute_command_with_status(self, command: str) -> tuple:
"""执行命令并返回 (退出码, 标准输出, 标准错误)"""
self._ensure_connected()
stdin, stdout, stderr = self.client.exec_command(command)
exit_code = stdout.channel.recv_exit_status()
return exit_code, stdout.read().decode(), stderr.read().decode()
def execute_command_with_timeout(self, command: str, timeout: int = 30) -> tuple:
"""带超时限制执行命令"""
"""Execute command with timeout"""
self._ensure_connected()
stdin, stdout, stderr = self.client.exec_command(command, timeout=timeout)
return stdout.read().decode(), stderr.read().decode()
def execute_command_with_timeout_and_status(self, command: str, timeout: int = 30) -> tuple:
"""带超时限制执行命令并返回状态码"""
self._ensure_connected()
stdin, stdout, stderr = self.client.exec_command(command, timeout=timeout)
exit_code = stdout.channel.recv_exit_status()
return exit_code, stdout.read().decode(), stderr.read().decode()
def read_file(self, file_path: str) -> str:
"""通过 SFTP 读取远程文件内容"""
"""Read file content from remote server"""
self._ensure_connected()
with self.client.open_sftp() as sftp:
@ -101,43 +94,37 @@ class SSHClient:
return f.read().decode()
def download_file(self, remote_path: str, local_path: str) -> None:
"""将远程文件下载到本地"""
"""Download file from remote server to local"""
self._ensure_connected()
with self.client.open_sftp() as sftp:
sftp.get(remote_path, local_path)
def close(self) -> None:
"""关闭 SSH 连接"""
"""Close SSH connection"""
if self.client:
self.client.close()
self.client = None
def __enter__(self):
"""支持上下文管理器语法 (with SSHClient(...) as client)"""
"""Context manager entry"""
self.connect()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
"""退出上下文管理器时自动关闭连接"""
"""Context manager exit"""
self.close()
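# Illustrative usage of the context-manager protocol above; host and credentials
# are placeholders.
# with SSHClient("hadoop102", "hadoop", "***") as cli:
#     out, err = cli.execute_command_with_timeout("jps -lm", timeout=10)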
class SSHConnectionManager:
"""
SSH connection manager: caches and reuses SSH connections to multiple nodes.
"""
"""SSH Connection Manager for managing multiple SSH connections"""
def __init__(self):
self.connections = {}
def get_connection(self, node_name: str, ip: str = None, username: str = None, password: str = None) -> SSHClient:
"""
Get or create an SSH connection for the given node.
If parameters change (e.g. the IP), the old connection is closed and a new one created.
"""
"""Get or create SSH connection for a node"""
if node_name in self.connections:
client = self.connections[node_name]
# Check whether the cached connection's parameters still match; tear it down if not
if ip and getattr(client, "hostname", None) != ip:
try:
client.close()
@ -157,7 +144,6 @@ class SSHConnectionManager:
pass
del self.connections[node_name]
# Create a new connection if none is cached
if node_name not in self.connections:
if not ip:
raise ValueError(f"IP address required for new connection to {node_name}")
@ -171,18 +157,20 @@ class SSHConnectionManager:
return self.connections[node_name]
def close_all(self) -> None:
"""关闭管理器中的所有 SSH 连接"""
"""Close all SSH connections"""
for conn in self.connections.values():
conn.close()
self.connections.clear()
def __enter__(self):
"""Context manager entry"""
return self
def __exit__(self, exc_type, exc_val, exc_tb):
"""Context manager exit"""
self.close_all()
# Global SSH connection manager instance
# Create a global SSH connection manager instance
ssh_manager = SSHConnectionManager()
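# Illustrative usage of the global manager; values mirror the STATIC_NODE_CONFIG
# entries above, password elided.
# client = ssh_manager.get_connection("hadoop102", ip="192.168.10.102",
#                                     username="hadoop", password="***")
# code, out, err = client.execute_command_with_timeout_and_status("jps", timeout=10)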

@ -7,50 +7,31 @@ from app.models.clusters import Cluster
from app.metrics_collector import metrics_collector
async def collect_once(cluster_uuid: str):
"""
Single collection pass.
1. Find the cluster ID by cluster UUID.
2. Fetch all node records for that cluster.
3. For each node, call metrics_collector to read CPU and memory data and persist it.
"""
async with SessionLocal() as session:
cid_res = await session.execute(select(Cluster.id).where(Cluster.uuid == cluster_uuid).limit(1))
cid = cid_res.scalars().first()
if not cid:
return
# Query all nodes belonging to the cluster
res = await session.execute(select(Node.id, Node.hostname, Node.ip_address).where(Node.cluster_id == cid))
rows = res.all()
for nid, hn, ip in rows:
# Read the node's CPU and memory usage
cpu, mem = metrics_collector._read_cpu_mem(hn, str(ip))
# Persist the metrics to the database / in-memory cache
await metrics_collector._save_metrics(nid, hn, cid, cpu, mem)
async def runner(cluster_uuid: str, interval: int):
"""
Collection runner: loops forever, collecting once per configured interval.
"""
while True:
try:
await collect_once(cluster_uuid)
except Exception:
# On collection errors, log or ignore so the loop never breaks
pass
await asyncio.sleep(interval)
def main():
"""
Worker process entry point: parse CLI arguments and start collection.
"""
parser = argparse.ArgumentParser()
parser.add_argument("--cluster", required=True, help="要采集指标的集群 UUID")
parser.add_argument("--interval", type=int, default=3, help="采集间隔秒数 (默认 3s)")
parser.add_argument("--cluster", required=True, help="Cluster UUID to collect metrics for")
parser.add_argument("--interval", type=int, default=3, help="Collect interval seconds")
args = parser.parse_args()
# Set the global collection interval
metrics_collector.set_collection_interval(args.interval)
# Start the async event loop running the collection task
asyncio.run(runner(args.cluster, args.interval))
if __name__ == "__main__":

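# Illustrative invocation of the worker (module path hypothetical):
#   python worker.py --cluster <cluster-uuid> --interval 3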