Compare commits

...

1 Commits

Author SHA1 Message Date
LYQ 32dcddce65 更新了新功能
2 months ago

@ -0,0 +1,8 @@
# 默认忽略的文件
/shelf/
/workspace.xml
# 基于编辑器的 HTTP 客户端请求
/httpRequests/
# Datasource local storage ignored files
/dataSources/
/dataSources.local.xml

@ -0,0 +1,16 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="CheckStyle-IDEA" serialisationVersion="2">
<checkstyleVersion>11.0.1</checkstyleVersion>
<scanScope>JavaOnly</scanScope>
<copyLibs>true</copyLibs>
<option name="thirdPartyClasspath" />
<option name="activeLocationIds" />
<option name="locations">
<list>
<ConfigurationLocation id="bundled-sun-checks" type="BUNDLED" scope="All" description="Sun Checks">(bundled)</ConfigurationLocation>
<ConfigurationLocation id="bundled-google-checks" type="BUNDLED" scope="All" description="Google Checks">(bundled)</ConfigurationLocation>
</list>
</option>
</component>
</project>

@ -0,0 +1,21 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="CompilerConfiguration">
<annotationProcessing>
<profile name="Annotation profile for springboot_demo" enabled="true">
<sourceOutputDir name="target/generated-sources/annotations" />
<sourceTestOutputDir name="target/generated-test-sources/test-annotations" />
<outputRelativeToContentRoot value="true" />
<processorPath useClasspath="false">
<entry name="$MAVEN_REPOSITORY$/org/projectlombok/lombok/1.18.42/lombok-1.18.42.jar" />
</processorPath>
<module name="springboot_demo" />
</profile>
</annotationProcessing>
</component>
<component name="JavacSettings">
<option name="ADDITIONAL_OPTIONS_OVERRIDE">
<module name="springboot_demo" options="-parameters" />
</option>
</component>
</project>

@ -0,0 +1,37 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="DataSourceManagerImpl" format="xml" multifile-model="true">
<data-source source="LOCAL" name="natural_language_query_system@localhost" uuid="b8d27d04-835f-4e42-bccb-3acb57e78840">
<driver-ref>mysql.8</driver-ref>
<synchronize>true</synchronize>
<imported>true</imported>
<remarks>$PROJECT_DIR$/test/src/main/resources/application.yml</remarks>
<jdbc-driver>com.mysql.cj.jdbc.Driver</jdbc-driver>
<jdbc-url>jdbc:mysql://localhost:3306/natural_language_query_system?useUnicode=true&amp;characterEncoding=utf8&amp;useSSL=false&amp;serverTimezone=Asia/Shanghai&amp;allowPublicKeyRetrieval=true</jdbc-url>
<jdbc-additional-properties>
<property name="com.intellij.clouds.kubernetes.db.host.port" />
<property name="com.intellij.clouds.kubernetes.db.enabled" value="false" />
<property name="com.intellij.clouds.kubernetes.db.container.port" />
</jdbc-additional-properties>
<working-dir>$ProjectFileDir$</working-dir>
</data-source>
<data-source source="LOCAL" name="natural_language_query_system@127.0.0.1" uuid="da4c9031-0352-480e-ad66-63e86248d59e">
<driver-ref>mongo</driver-ref>
<synchronize>true</synchronize>
<imported>true</imported>
<remarks>$PROJECT_DIR$/test/src/main/resources/application.yml</remarks>
<jdbc-driver>com.dbschema.MongoJdbcDriver</jdbc-driver>
<jdbc-url>mongodb://127.0.0.1:27017/natural_language_query_system</jdbc-url>
<working-dir>$ProjectFileDir$</working-dir>
</data-source>
<data-source source="LOCAL" name="0@localhost" uuid="945204ff-b09b-4daa-846b-6ab000c85369">
<driver-ref>redis</driver-ref>
<synchronize>true</synchronize>
<imported>true</imported>
<remarks>$PROJECT_DIR$/test/src/main/resources/application.yml</remarks>
<jdbc-driver>jdbc.RedisDriver</jdbc-driver>
<jdbc-url>jdbc:redis://localhost:6379/0</jdbc-url>
<working-dir>$ProjectFileDir$</working-dir>
</data-source>
</component>
</project>

@ -0,0 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="Encoding">
<file url="file://$PROJECT_DIR$/test/src/main/java" charset="UTF-8" />
</component>
</project>

@ -0,0 +1,5 @@
<component name="InspectionProjectProfileManager">
<settings>
<option name="PROJECT_PROFILE" />
</settings>
</component>

@ -0,0 +1,20 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="RemoteRepositoriesConfiguration">
<remote-repository>
<option name="id" value="central" />
<option name="name" value="Maven Central repository" />
<option name="url" value="https://repo1.maven.org/maven2" />
</remote-repository>
<remote-repository>
<option name="id" value="jboss.community" />
<option name="name" value="JBoss Community repository" />
<option name="url" value="https://repository.jboss.org/nexus/content/repositories/public/" />
</remote-repository>
<remote-repository>
<option name="id" value="central" />
<option name="name" value="Central Repository" />
<option name="url" value="https://maven.aliyun.com/repository/public" />
</remote-repository>
</component>
</project>

@ -0,0 +1,14 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ExternalStorageConfigurationManager" enabled="true" />
<component name="MavenProjectsManager">
<option name="originalFiles">
<list>
<option value="$PROJECT_DIR$/test/pom.xml" />
</list>
</option>
</component>
<component name="ProjectRootManager" version="2" languageLevel="JDK_X" default="true" project-jdk-name="19.0" project-jdk-type="JavaSDK">
<output url="file://$PROJECT_DIR$/out" />
</component>
</project>

@ -0,0 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectModuleManager">
<modules>
<module fileurl="file://$PROJECT_DIR$/.idea/src.iml" filepath="$PROJECT_DIR$/.idea/src.iml" />
</modules>
</component>
</project>

@ -0,0 +1,9 @@
<?xml version="1.0" encoding="UTF-8"?>
<module type="JAVA_MODULE" version="4">
<component name="NewModuleRootManager" inherit-compiler-output="true">
<exclude-output />
<content url="file://$MODULE_DIR$" />
<orderEntry type="inheritedJdk" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
</module>

@ -0,0 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="VcsDirectoryMappings">
<mapping directory="$PROJECT_DIR$/.." vcs="Git" />
</component>
</project>

@ -1,6 +1,7 @@
import React, { useState } from 'react';
import { UserRole } from '../types';
import { authApi, LoginRequest } from '../services/api';
import { logOperation, LogModule, LogOperationType, LogStatus } from '../utils/logger';
interface LoginPageProps {
onLogin: (role: UserRole) => void;
@ -55,10 +56,16 @@ export const LoginPage: React.FC<LoginPageProps> = ({ onLogin }) => {
// 保存用户角色到sessionStorage
sessionStorage.setItem('userRole', role);
// 记录登录成功日志
await logOperation(LogModule.SYSTEM, LogOperationType.LOGIN, `用户 ${response.username} 登录系统`, LogStatus.SUCCESS);
// 刷新页面以确保所有组件状态都被清理(处理切换账户的情况)
window.location.reload();
} catch (err) {
setError(err instanceof Error ? err.message : '登录失败,请检查用户名和密码');
const errorMessage = err instanceof Error ? err.message : '登录失败,请检查用户名和密码';
setError(errorMessage);
// 记录登录失败日志(不阻塞用户体验)
await logOperation(LogModule.SYSTEM, LogOperationType.LOGIN, `用户 ${username} 登录失败:${errorMessage}`, LogStatus.FAILURE);
} finally {
setIsLoading(false);
}

@ -5,7 +5,7 @@ import { Dropdown } from './Dropdown';
import { DATABASE_OPTIONS } from '../constants';
import { HistorySidebar } from './HistorySidebar';
import { RightSidebar } from './RightSidebar';
import { queryApi, QueryResponse, llmConfigApi } from '../services/api';
import { queryApi, QueryResponse, llmConfigApi, dbConnectionApi } from '../services/api';
interface QueryPageProps {
currentConversation: Conversation | undefined;
@ -43,7 +43,9 @@ export const QueryPage: React.FC<QueryPageProps> = ({
const [prompt, setPrompt] = useState('');
const [modelOptions, setModelOptions] = useState<Array<{id: string, name: string, disabled: boolean, description: string}>>([]);
const [selectedModelId, setSelectedModelId] = useState('');
const [selectedDatabase, setSelectedDatabase] = useState(DATABASE_OPTIONS[0].name);
const [databaseOptions, setDatabaseOptions] = useState<Array<{id: string, name: string, disabled: boolean, description: string}>>([]);
const [selectedDatabaseId, setSelectedDatabaseId] = useState('');
const [selectedDatabase, setSelectedDatabase] = useState(''); // 保留用于显示
const [isLoading, setIsLoading] = useState(false);
const [error, setError] = useState<string | null>(null);
const [abortController, setAbortController] = useState<AbortController | null>(null);
@ -51,9 +53,10 @@ export const QueryPage: React.FC<QueryPageProps> = ({
const [pendingConversationId, setPendingConversationId] = useState<string | null>(null);
const chatContainerRef = useRef<HTMLDivElement>(null);
// 从后端加载可用的大模型配置
// 从后端加载可用的大模型配置和数据库连接
useEffect(() => {
loadAvailableModels();
loadDatabaseConnections();
}, []);
const loadAvailableModels = async () => {
@ -77,6 +80,30 @@ export const QueryPage: React.FC<QueryPageProps> = ({
}
};
const loadDatabaseConnections = async () => {
try {
const connections = await dbConnectionApi.getList();
// 只显示未禁用的连接
const activeConnections = connections.filter(conn => conn.status !== 'disabled');
const options = activeConnections.map(conn => ({
id: String(conn.id),
name: conn.name,
disabled: false,
description: `${conn.name} - ${conn.url}`,
}));
setDatabaseOptions(options);
if (options.length > 0) {
setSelectedDatabaseId(options[0].id);
setSelectedDatabase(options[0].name);
}
} catch (error) {
console.error('加载数据库连接失败:', error);
// 如果加载失败,使用默认选项
setDatabaseOptions([]);
setError('无法加载数据库连接,请联系管理员');
}
};
// 自动滚动到底部
useEffect(() => {
if (chatContainerRef.current) {
@ -123,11 +150,12 @@ export const QueryPage: React.FC<QueryPageProps> = ({
try {
if (!currentConversation) throw new Error("No active conversation.");
// 3. 调用后端API传递模型配置ID
// 3. 调用后端API传递模型配置ID和数据库连接ID
const response: QueryResponse = await queryApi.execute({
userPrompt: finalPrompt,
model: selectedModelId, // 传递模型配置ID而不是名称
database: selectedDatabase,
model: selectedModelId, // 传递模型配置ID
database: selectedDatabase, // 数据库名称用于显示和LLM上下文
dbConnectionId: Number(selectedDatabaseId), // 数据库连接ID用于实际连接
conversationId: currentConversation.id !== 'conv-1' ? currentConversation.id : undefined,
});
@ -242,7 +270,20 @@ export const QueryPage: React.FC<QueryPageProps> = ({
}}
icon="fa-cogs"
/>
<Dropdown options={DATABASE_OPTIONS} selected={selectedDatabase} setSelected={setSelectedDatabase} icon="fa-database" />
<Dropdown
options={databaseOptions.length > 0 ? databaseOptions : DATABASE_OPTIONS}
selected={selectedDatabase}
setSelected={(name) => {
const option = databaseOptions.find(opt => opt.name === name);
if (option) {
setSelectedDatabaseId(option.id);
setSelectedDatabase(option.name);
} else {
setSelectedDatabase(name);
}
}}
icon="fa-database"
/>
</div>
<form onSubmit={handleSubmit} className="relative">
<textarea

@ -24,7 +24,7 @@ const defaultShowActions = { save: true, share: true, export: true };
export const QueryResult: React.FC<QueryResultProps> = ({ result, onSaveQuery, onShareQuery, savedQueries = [], showActions = defaultShowActions }) => {
const [activeView, setActiveView] = useState<ViewType>('table');
const [chartType, setChartType] = useState<ChartType>(result.chartData.type);
const [chartType, setChartType] = useState<ChartType>(result.chartData?.type || 'bar');
const [activeModal, setActiveModal] = useState<string | null>(null);
const [copySuccess, setCopySuccess] = useState(false);
const [selectedFriendId, setSelectedFriendId] = useState<string | null>(null);
@ -33,11 +33,13 @@ export const QueryResult: React.FC<QueryResultProps> = ({ result, onSaveQuery, o
const chartRef = useRef<Chart | null>(null);
useEffect(() => {
setChartType(result.chartData.type);
}, [result.chartData.type]);
if (result.chartData?.type) {
setChartType(result.chartData.type);
}
}, [result.chartData?.type]);
useEffect(() => {
if (activeView === 'chart' && canvasRef.current) {
if (activeView === 'chart' && canvasRef.current && result.chartData) {
if (chartRef.current) {
chartRef.current.destroy();
}
@ -140,6 +142,17 @@ export const QueryResult: React.FC<QueryResultProps> = ({ result, onSaveQuery, o
</div>
);
case 'chart':
// 如果没有图表数据,显示提示信息
if (!result.chartData || !result.chartData.labels || result.chartData.labels.length === 0) {
return (
<div className="flex flex-col items-center justify-center py-12 text-gray-500">
<i className="fa fa-chart-bar text-6xl mb-4 opacity-20"></i>
<p className="text-lg"></p>
<p className="text-sm mt-2"></p>
</div>
);
}
const ChartTypeButton: React.FC<{ type: ChartType; label: string; icon: string; }> = ({ type, label, icon }) => (
<button
onClick={() => setChartType(type)}

@ -1,6 +1,6 @@
import React, { useState, useRef, useEffect } from 'react';
import { AdminModal } from './AdminModal';
import { userApi } from '../../services/api';
import { userApi, ChangePasswordRequest } from '../../services/api';
export const AdminAccountPage: React.FC = () => {
const userId = Number(sessionStorage.getItem('userId') || '1');
@ -82,9 +82,24 @@ export const AdminAccountPage: React.FC = () => {
alert('新密码长度不能少于6位');
return;
}
alert('密码修改功能待后端接口实现');
setModal(null);
setPasswordForm({ oldPassword: '', newPassword: '', confirmPassword: '' });
if (!passwordForm.oldPassword) {
alert('请输入当前密码!');
return;
}
try {
const changePasswordRequest: ChangePasswordRequest = {
userId: userId,
oldPassword: passwordForm.oldPassword,
newPassword: passwordForm.newPassword,
};
await userApi.changePassword(changePasswordRequest);
alert('密码修改成功!');
setModal(null);
setPasswordForm({ oldPassword: '', newPassword: '', confirmPassword: '' });
} catch (error) {
console.error('修改密码失败:', error);
alert(error instanceof Error ? error.message : '修改失败,请检查当前密码是否正确');
}
};
const handleAvatarChange = async (e: React.ChangeEvent<HTMLInputElement>) => {

@ -1,7 +1,7 @@
import React, { useState, useMemo, useEffect } from 'react';
import { AdminModal } from './AdminModal';
import { SystemLog } from '../../types';
import { MOCK_SYSTEM_LOGS } from '../../constants';
import { operationLogApi, userApi } from '../../services/api';
interface SystemLogPageProps {
initialStatusFilter: string;
@ -9,18 +9,53 @@ interface SystemLogPageProps {
}
export const SystemLogPage: React.FC<SystemLogPageProps> = ({ initialStatusFilter, clearInitialFilter }) => {
const [logs, setLogs] = useState<SystemLog[]>(MOCK_SYSTEM_LOGS);
const [logs, setLogs] = useState<SystemLog[]>([]);
const [loading, setLoading] = useState(true);
const [filters, setFilters] = useState({ startDate: '', endDate: '', user: '', action: '', status: initialStatusFilter });
const [isExportModalOpen, setExportModalOpen] = useState(false);
const [viewingLog, setViewingLog] = useState<SystemLog | null>(null);
useEffect(() => {
loadLogs();
// Clear the initial filter from the parent so it's not reapplied on re-renders
if (initialStatusFilter) {
clearInitialFilter();
}
}, [initialStatusFilter, clearInitialFilter]);
const loadLogs = async () => {
try {
setLoading(true);
// 并行加载日志和用户数据
const [logsData, usersData] = await Promise.all([
operationLogApi.getList(),
userApi.getList()
]);
const userMap = new Map(usersData.map(u => [u.id, u.username]));
const systemLogs: SystemLog[] = logsData.map(log => ({
id: `#LOG${log.id}`,
time: log.operateTime.replace('T', ' '),
user: userMap.get(log.userId) || `用户#${log.userId}`,
action: log.operateDesc,
model: log.module,
ip: log.ip,
status: log.status === 1 ? 'success' : 'failure',
details: log.operateType
}));
// 按时间倒序排序
systemLogs.sort((a, b) => new Date(b.time).getTime() - new Date(a.time).getTime());
setLogs(systemLogs);
} catch (error) {
console.error('加载系统日志失败:', error);
} finally {
setLoading(false);
}
};
const filteredLogs = useMemo(() => {
return logs.filter(log =>
(!filters.startDate || log.time >= filters.startDate) &&

@ -2,6 +2,7 @@ import React, { useState, useMemo, useCallback, useEffect } from 'react';
import { AdminModal } from './AdminModal';
import { AdminUser, UserRole } from '../../types';
import { userApi, User } from '../../services/api';
import { logOperation, LogModule, LogOperationType, LogStatus } from '../../utils/logger';
export const UserManagementPage: React.FC = () => {
const [users, setUsers] = useState<AdminUser[]>([]);
@ -92,10 +93,12 @@ export const UserManagementPage: React.FC = () => {
if (userToProcess) {
try {
await userApi.delete(userToProcess.id);
await logOperation(LogModule.USER_MANAGEMENT, LogOperationType.DELETE, `删除用户:${userToProcess.username}`, LogStatus.SUCCESS);
setUsers(prev => prev.filter(u => u.id !== userToProcess.id));
alert('删除成功');
} catch (error) {
console.error('删除用户失败:', error);
await logOperation(LogModule.USER_MANAGEMENT, LogOperationType.DELETE, `删除用户失败:${error instanceof Error ? error.message : '未知错误'}`, LogStatus.FAILURE);
alert('删除失败: ' + (error instanceof Error ? error.message : '未知错误'));
}
}
@ -111,14 +114,18 @@ export const UserManagementPage: React.FC = () => {
if(userToProcess) {
try {
const newStatus = userToProcess.status === 'active' ? 0 : 1;
const operationType = userToProcess.status === 'active' ? LogOperationType.DISABLE : LogOperationType.ENABLE;
await userApi.update({
id: userToProcess.id,
status: newStatus,
});
await logOperation(LogModule.USER_MANAGEMENT, operationType, `${operationType}用户:${userToProcess.username}`, LogStatus.SUCCESS);
setUsers(prev => prev.map(u => u.id === userToProcess.id ? { ...u, status: u.status === 'active' ? 'disabled' : 'active' } : u));
alert('操作成功');
} catch (error) {
console.error('更新用户状态失败:', error);
const operationType = userToProcess.status === 'active' ? LogOperationType.DISABLE : LogOperationType.ENABLE;
await logOperation(LogModule.USER_MANAGEMENT, operationType, `${operationType}用户失败:${error instanceof Error ? error.message : '未知错误'}`, LogStatus.FAILURE);
alert('操作失败: ' + (error instanceof Error ? error.message : '未知错误'));
}
}

@ -1,16 +1,51 @@
import React, { useState, useMemo } from 'react';
import React, { useState, useMemo, useEffect } from 'react';
import { AdminModal } from '../admin/AdminModal';
import { ConnectionLog } from '../../types';
import { MOCK_CONNECTION_LOGS } from '../../constants';
import { dbConnectionLogApi, dbConnectionApi } from '../../services/api';
export const ConnectionLogPage: React.FC = () => {
const [logs, setLogs] = useState<ConnectionLog[]>(MOCK_CONNECTION_LOGS);
const [logs, setLogs] = useState<ConnectionLog[]>([]);
const [loading, setLoading] = useState(true);
const [isExportModalOpen, setExportModalOpen] = useState(false);
const [viewingLog, setViewingLog] = useState<ConnectionLog | null>(null);
// 新增搜索相关状态
const [searchTerm, setSearchTerm] = useState('');
const [searchType, setSearchType] = useState<'all' | 'time' | 'datasource' | 'status'>('all');
// 加载连接日志数据
useEffect(() => {
loadConnectionLogs();
}, []);
const loadConnectionLogs = async () => {
try {
setLoading(true);
const [logsData, connections] = await Promise.all([
dbConnectionLogApi.getList(),
dbConnectionApi.getList()
]);
const connectionMap = new Map(connections.map(c => [c.id, c.name]));
const formattedLogs: ConnectionLog[] = logsData.map(log => ({
id: String(log.id),
time: log.connectTime,
datasource: connectionMap.get(log.dbConnectionId) || `数据源#${log.dbConnectionId}`,
status: log.status === 1 ? '成功' : '失败',
details: log.errorMessage || undefined
}));
// 按时间倒序排序
formattedLogs.sort((a, b) => new Date(b.time).getTime() - new Date(a.time).getTime());
setLogs(formattedLogs);
} catch (error) {
console.error('加载连接日志失败:', error);
} finally {
setLoading(false);
}
};
const getStatusClass = (status: '成功' | '失败') => {
return status === '成功' ? 'text-success' : 'text-danger';
};
@ -43,6 +78,19 @@ export const ConnectionLogPage: React.FC = () => {
});
}, [logs, searchTerm, searchType]);
if (loading) {
return (
<main className="flex-1 overflow-y-auto p-6 space-y-6">
<div className="flex items-center justify-center h-64">
<div className="text-center">
<i className="fa fa-spinner fa-spin text-3xl text-primary mb-4"></i>
<p className="text-gray-500">...</p>
</div>
</div>
</main>
);
}
return (
<main className="flex-1 overflow-y-auto p-6 space-y-6">
{/* 修改顶部区域,添加搜索功能 */}

@ -2,6 +2,7 @@ import React, { useState, useMemo, useEffect } from 'react';
import { AdminModal } from '../admin/AdminModal';
import { DataSource } from '../../types';
import { dbConnectionApi, DbConnection } from '../../services/api';
import { logOperation, LogModule, LogOperationType, LogStatus } from '../../utils/logger';
export const DataSourceManagementPage: React.FC = () => {
const [dataSources, setDataSources] = useState<DataSource[]>([]);
@ -110,6 +111,10 @@ export const DataSourceManagementPage: React.FC = () => {
alert('数据库端口不能为空');
return;
}
if (!data.database || !(data.database as string).trim()) {
alert('数据库名称不能为空');
return;
}
if (!data.username || !(data.username as string).trim()) {
alert('数据库账号不能为空');
return;
@ -124,13 +129,14 @@ export const DataSourceManagementPage: React.FC = () => {
const backendConnection: Partial<DbConnection> = {
name: (data.name as string).trim(),
dbTypeId: mapTypeToDbTypeId(data.type as DataSource['type']),
url: `${(data.host as string).trim()}:${(data.port as string).trim()}`,
url: `${(data.host as string).trim()}:${(data.port as string).trim()}/${(data.database as string).trim()}`,
username: (data.username as string).trim(),
password: (data.password as string).trim(),
status: 'disconnected',
createUserId: Number(sessionStorage.getItem('userId') || '1'),
};
await dbConnectionApi.create(backendConnection);
await logOperation(LogModule.DATA_SOURCE, LogOperationType.CREATE, `创建数据源:${backendConnection.name}`, LogStatus.SUCCESS);
alert('添加成功');
await loadDataSources();
} else if (modal === 'edit' && currentItem) {
@ -138,14 +144,17 @@ export const DataSourceManagementPage: React.FC = () => {
id: Number(currentItem.id),
name: data.name as string,
dbTypeId: mapTypeToDbTypeId(data.type as DataSource['type']),
url: `${data.host}:${data.port}`,
url: `${data.host}:${data.port}/${data.database}`,
};
await dbConnectionApi.update(backendConnection);
await logOperation(LogModule.DATA_SOURCE, LogOperationType.UPDATE, `更新数据源:${backendConnection.name}`, LogStatus.SUCCESS);
alert('更新成功');
await loadDataSources();
}
} catch (error) {
console.error('保存数据源失败:', error);
const operationType = modal === 'add' ? LogOperationType.CREATE : LogOperationType.UPDATE;
await logOperation(LogModule.DATA_SOURCE, operationType, `保存数据源失败:${error instanceof Error ? error.message : '未知错误'}`, LogStatus.FAILURE);
alert('保存失败: ' + (error instanceof Error ? error.message : '未知错误'));
return;
}
@ -156,10 +165,12 @@ export const DataSourceManagementPage: React.FC = () => {
if (currentItem) {
try {
await dbConnectionApi.delete(Number(currentItem.id));
await logOperation(LogModule.DATA_SOURCE, LogOperationType.DELETE, `删除数据源:${currentItem.name}`, LogStatus.SUCCESS);
alert('删除成功');
await loadDataSources();
} catch (error) {
console.error('删除数据源失败:', error);
await logOperation(LogModule.DATA_SOURCE, LogOperationType.DELETE, `删除数据源失败:${error instanceof Error ? error.message : '未知错误'}`, LogStatus.FAILURE);
alert('删除失败: ' + (error instanceof Error ? error.message : '未知错误'));
}
}
@ -167,6 +178,7 @@ export const DataSourceManagementPage: React.FC = () => {
};
const handleTestConnection = async (id: string) => {
const dataSource = dataSources.find(ds => ds.id === id);
setDataSources(prev => prev.map(ds => ds.id === id ? { ...ds, status: 'testing' } : ds));
try {
const result = await dbConnectionApi.test(Number(id));
@ -177,13 +189,16 @@ export const DataSourceManagementPage: React.FC = () => {
return ds;
}));
if (result) {
await logOperation(LogModule.DATA_SOURCE, LogOperationType.TEST, `测试数据源连接成功:${dataSource?.name}`, LogStatus.SUCCESS);
alert('连接测试成功');
} else {
await logOperation(LogModule.DATA_SOURCE, LogOperationType.TEST, `测试数据源连接失败:${dataSource?.name}`, LogStatus.FAILURE);
alert('连接测试失败');
}
} catch (error) {
console.error('测试连接失败:', error);
setDataSources(prev => prev.map(ds => ds.id === id ? { ...ds, status: 'error' } : ds));
await logOperation(LogModule.DATA_SOURCE, LogOperationType.TEST, `测试数据源连接异常:${dataSource?.name} - ${error instanceof Error ? error.message : '未知错误'}`, LogStatus.FAILURE);
alert('测试连接失败: ' + (error instanceof Error ? error.message : '未知错误'));
}
};
@ -197,14 +212,18 @@ export const DataSourceManagementPage: React.FC = () => {
if (currentItem) {
try {
const newStatus = currentItem.status === 'disabled' ? 'disconnected' : 'disabled';
const operationType = currentItem.status === 'disabled' ? LogOperationType.ENABLE : LogOperationType.DISABLE;
await dbConnectionApi.update({
id: Number(currentItem.id),
status: newStatus,
});
await logOperation(LogModule.DATA_SOURCE, operationType, `${operationType}数据源:${currentItem.name}`, LogStatus.SUCCESS);
alert('操作成功');
await loadDataSources();
} catch (error) {
console.error('切换数据源状态失败:', error);
const operationType = currentItem.status === 'disabled' ? LogOperationType.ENABLE : LogOperationType.DISABLE;
await logOperation(LogModule.DATA_SOURCE, operationType, `${operationType}数据源失败:${error instanceof Error ? error.message : '未知错误'}`, LogStatus.FAILURE);
alert('操作失败: ' + (error instanceof Error ? error.message : '未知错误'));
}
}
@ -279,8 +298,9 @@ export const DataSourceManagementPage: React.FC = () => {
<form onSubmit={handleSave} className="space-y-4">
<div><label className="block text-sm mb-1"></label><input name="name" defaultValue={currentItem?.name} required className="w-full px-3 py-2 border rounded-lg"/></div>
<div><label className="block text-sm mb-1"></label><select name="type" defaultValue={currentItem?.type} className="w-full px-3 py-2 border rounded-lg"><option value="MySQL">MySQL</option><option value="PostgreSQL">PostgreSQL</option><option value="Oracle">Oracle</option><option value="SQL Server">SQL Server</option></select></div>
<div><label className="block text-sm mb-1"></label><input name="host" defaultValue={currentItem?.address.split(':')[0]} required className="w-full px-3 py-2 border rounded-lg"/></div>
<div><label className="block text-sm mb-1"></label><input name="port" defaultValue={currentItem?.address.split(':')[1]} required className="w-full px-3 py-2 border rounded-lg"/></div>
<div><label className="block text-sm mb-1">IP</label><input name="host" defaultValue={currentItem?.address.split(':')[0].split('/')[0]} required className="w-full px-3 py-2 border rounded-lg" placeholder="例如localhost 或 192.168.1.100"/></div>
<div><label className="block text-sm mb-1"></label><input name="port" defaultValue={currentItem?.address.split(':')[1]?.split('/')[0]} required className="w-full px-3 py-2 border rounded-lg" placeholder="例如3306"/></div>
<div><label className="block text-sm mb-1"></label><input name="database" defaultValue={currentItem?.address.split('/')[1] || ''} required className="w-full px-3 py-2 border rounded-lg" placeholder="例如natural_language_query_system"/></div>
<div><label className="block text-sm mb-1"></label><input name="username" type="text" required className="w-full px-3 py-2 border rounded-lg"/></div>
<div><label className="block text-sm mb-1"></label><input name="password" type="password" required className="w-full px-3 py-2 border rounded-lg"/></div>
<div className="flex justify-end space-x-2 pt-4"><button type="button" onClick={() => setModal(null)} className="px-4 py-2 border rounded-lg"></button><button type="submit" className="px-4 py-2 bg-primary text-white rounded-lg"></button></div>

@ -2,6 +2,7 @@ import React, { useState, useMemo, useEffect } from 'react';
import { AdminModal } from '../admin/AdminModal';
import { UserPermissionAssignment, UnassignedUser, DataSourcePermission } from '../../types';
import { userDbPermissionApi, UserDbPermission, userApi, User, dbConnectionApi, DbConnection } from '../../services/api';
import { logOperation, LogModule, LogOperationType, LogStatus } from '../../utils/logger';
export const UserPermissionPage: React.FC = () => {
const [unassignedUsers, setUnassignedUsers] = useState<UnassignedUser[]>([]);
@ -161,16 +162,20 @@ type SearchCategory = 'all' | 'username' | 'email' | 'datasource' | 'table';
const currentUserId = Number(sessionStorage.getItem('userId') || '1');
const permissionDetails = filteredPerms.map(p => ({
db_connection_id: Number(p.dataSourceId),
table_ids: p.tables.map(t => Number(t.replace('table_', ''))),
// 存储表名而不是 ID因为目前没有 table_metadata ID
table_names: p.tables,
table_ids: [] // 保持兼容性,传空数组
}));
for (const userId of userIds) {
const user = usersToAssign.find(u => u.id === userId);
await userDbPermissionApi.create({
userId: Number(userId),
permissionDetails: JSON.stringify(permissionDetails),
isAssigned: 1,
lastGrantUserId: currentUserId,
});
await logOperation(LogModule.PERMISSION, LogOperationType.ASSIGN, `为用户 ${user?.username} 分配数据权限`, LogStatus.SUCCESS);
}
alert('分配权限成功');
await loadData();
@ -178,18 +183,23 @@ type SearchCategory = 'all' | 'username' | 'email' | 'datasource' | 'table';
// 更新权限
const permissionDetails = filteredPerms.map(p => ({
db_connection_id: Number(p.dataSourceId),
table_ids: p.tables.map(t => Number(t.replace('table_', ''))),
// 存储表名
table_names: p.tables,
table_ids: []
}));
await userDbPermissionApi.update({
id: Number(currentItem.id),
permissionDetails: JSON.stringify(permissionDetails),
lastGrantUserId: Number(sessionStorage.getItem('userId') || '1'),
});
await logOperation(LogModule.PERMISSION, LogOperationType.UPDATE, `更新用户 ${currentItem.username} 的数据权限`, LogStatus.SUCCESS);
alert('更新权限成功');
await loadData();
}
} catch (error) {
console.error('保存权限失败:', error);
const operationType = modal === 'assign' ? LogOperationType.ASSIGN : LogOperationType.UPDATE;
await logOperation(LogModule.PERMISSION, operationType, `保存权限失败:${error instanceof Error ? error.message : '未知错误'}`, LogStatus.FAILURE);
alert('保存权限失败: ' + (error instanceof Error ? error.message : '未知错误'));
return;
}
@ -226,45 +236,68 @@ type SearchCategory = 'all' | 'username' | 'email' | 'datasource' | 'table';
};
const handleSelectAllTables = (dsId: string, checked: boolean) => {
// 简化处理:假设每个数据源有默认的表列表
// 实际应该从后端获取表列表
const defaultTables = ['table_1', 'table_2', 'table_3']; // 临时处理
setPerms(prev => prev.map(p => p.dataSourceId === dsId ? { ...p, tables: checked ? defaultTables : [] } : p));
const currentTables = availableTables[dsId] || [];
setPerms(prev => prev.map(p => p.dataSourceId === dsId ? { ...p, tables: checked ? currentTables : [] } : p));
};
const title = users.length > 1 ? `${users.length} 位用户分配权限` : `${users[0].username} 分配权限`;
const isEditing = existingPermissions.length > 0;
// 加载表列表
const [availableTables, setAvailableTables] = useState<Record<string, string[]>>({});
useEffect(() => {
const loadTables = async () => {
const tablesMap: Record<string, string[]> = {};
for (const ds of dataSources) {
try {
const tables = await dbConnectionApi.getTables(Number(ds.id));
tablesMap[String(ds.id)] = tables;
} catch (error) {
console.error(`加载数据源 ${ds.name} 表失败:`, error);
tablesMap[String(ds.id)] = [];
}
}
setAvailableTables(tablesMap);
};
loadTables();
}, []);
return (
<AdminModal isOpen={true} onClose={onClose} title={isEditing ? `管理 ${users[0].username} 的权限` : title}>
<div className="space-y-4 max-h-[60vh] overflow-y-auto pr-2">
{dataSources.map(ds => {
const currentPerm = perms.find(p => p.dataSourceId === String(ds.id));
// 简化处理:使用默认表列表,实际应该从后端获取
const allTablesForDs = ['table_1', 'table_2', 'table_3'];
const allSelected = currentPerm ? currentPerm.tables.length === allTablesForDs.length : false;
const allTablesForDs = availableTables[String(ds.id)] || [];
const allSelected = currentPerm ? currentPerm.tables.length === allTablesForDs.length && allTablesForDs.length > 0 : false;
return (
<div key={ds.id} className="p-3 border rounded-lg">
<h4 className="font-semibold mb-2">{ds.name}</h4>
<div className="border-t pt-2">
<label className="flex items-center mb-2 font-medium text-sm">
<input type="checkbox" onChange={(e) => handleSelectAllTables(String(ds.id), e.target.checked)} checked={allSelected} className="mr-2 h-4 w-4" />
</label>
<div className="grid grid-cols-2 gap-2 text-sm">
{allTablesForDs.map(table => (
<label key={table} className="flex items-center">
<input
type="checkbox"
checked={currentPerm?.tables.includes(table) || false}
onChange={(e) => handleTableToggle(String(ds.id), table, e.target.checked)}
className="mr-2 h-4 w-4"
/>
{table}
</label>
))}
</div>
{allTablesForDs.length > 0 ? (
<>
<label className="flex items-center mb-2 font-medium text-sm">
<input type="checkbox" onChange={(e) => handleSelectAllTables(String(ds.id), e.target.checked)} checked={allSelected} className="mr-2 h-4 w-4" />
</label>
<div className="grid grid-cols-2 gap-2 text-sm">
{allTablesForDs.map(table => (
<label key={table} className="flex items-center">
<input
type="checkbox"
checked={currentPerm?.tables.includes(table) || false}
onChange={(e) => handleTableToggle(String(ds.id), table, e.target.checked)}
className="mr-2 h-4 w-4"
/>
{table}
</label>
))}
</div>
</>
) : (
<div className="text-sm text-gray-500 py-2"></div>
)}
</div>
</div>
)

@ -111,6 +111,7 @@ export interface QueryRequest {
userPrompt: string;
model: string;
database: string;
dbConnectionId?: number; // 数据库连接ID用于实际连接
conversationId?: string;
}
@ -293,6 +294,10 @@ export const dbConnectionApi = {
test: async (id: number): Promise<boolean> => {
return await request<boolean>(`/db-connection/test/${id}`);
},
getTables: async (id: number): Promise<string[]> => {
return await request<string[]>(`/db-connection/${id}/tables`);
},
};
// ==================== 大模型配置接口 ====================

@ -37,7 +37,7 @@ export interface QueryResultData {
queryTime: string;
executionTime: string;
tableData: TableData;
chartData: ChartData;
chartData?: ChartData;
database: string;
model:string;
}

@ -0,0 +1,77 @@
/**
*
*
*/
import { operationLogApi } from '../services/api';
/**
*
* @param module
* @param operateType
* @param operateDesc
* @param status 1: , 0:
*/
export const logOperation = async (
module: string,
operateType: string,
operateDesc: string,
status: number = 1
): Promise<void> => {
try {
const userId = Number(sessionStorage.getItem('userId') || '0');
if (!userId) {
console.warn('未找到用户ID无法记录日志');
return;
}
await operationLogApi.create({
userId,
module,
operateType,
operateDesc,
status,
ip: 'unknown', // 前端无法直接获取IP由后端补充
});
} catch (error) {
console.error('记录操作日志失败:', error);
// 日志记录失败不应该影响业务操作,所以只打印错误
}
};
/**
*
*/
export const LogModule = {
USER_MANAGEMENT: '用户管理',
DATA_SOURCE: '数据源管理',
LLM_CONFIG: '大模型配置',
PERMISSION: '权限管理',
NOTIFICATION: '通知管理',
QUERY: '查询操作',
SYSTEM: '系统操作',
} as const;
/**
*
*/
export const LogOperationType = {
CREATE: '创建',
UPDATE: '更新',
DELETE: '删除',
ENABLE: '启用',
DISABLE: '禁用',
TEST: '测试',
ASSIGN: '分配',
PUBLISH: '发布',
LOGIN: '登录',
LOGOUT: '登出',
} as const;
/**
*
*/
export const LogStatus = {
SUCCESS: 1,
FAILURE: 0,
} as const;

@ -0,0 +1,17 @@
-- 更新现有的数据库连接 URL 格式
-- 如果你之前添加的数据源没有包含数据库名,需要手动更新
-- 示例:将 localhost:3306 更新为 localhost:3306/natural_language_query_system
-- UPDATE db_connections
-- SET url = CONCAT(url, '/natural_language_query_system')
-- WHERE url NOT LIKE '%/%';
-- 请根据实际情况修改数据库名称
-- 查看当前所有连接:
SELECT id, name, url FROM db_connections;
-- 更新特定连接请根据实际ID修改
-- UPDATE db_connections
-- SET url = 'localhost:3306/natural_language_query_system'
-- WHERE id = 1;

@ -3,6 +3,7 @@ package com.example.springboot_demo.controller;
import com.example.springboot_demo.common.Result;
import com.example.springboot_demo.entity.mysql.DbConnection;
import com.example.springboot_demo.service.DbConnectionService;
import com.example.springboot_demo.service.DatabaseSchemaService;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.*;
@ -104,6 +105,22 @@ public class DbConnectionController {
boolean result = dbConnectionService.testConnection(id);
return Result.success(result);
}
@Autowired
private DatabaseSchemaService databaseSchemaService;
/**
*
*/
@GetMapping("/{id}/tables")
public Result<List<String>> getTables(@PathVariable Long id) {
DbConnection connection = dbConnectionService.getById(id);
if (connection == null) {
return Result.error("数据源不存在");
}
List<String> tables = databaseSchemaService.getTableNames(connection);
return Result.success(tables);
}
}

@ -5,8 +5,9 @@ import lombok.Data;
@Data
public class QueryRequestDTO {
private String userPrompt; // 用户的自然语言查询
private String model; // 使用的大模型
private String database; // 使用的数据库
private String model; // 使用的大模型LLM配置ID
private String database; // 使用的数据库兼容旧版实际使用dbConnectionId
private Long dbConnectionId; // 数据库连接ID新增
private String conversationId; // 对话ID用于多轮对话
}

@ -0,0 +1,33 @@
package com.example.springboot_demo.service;
import com.example.springboot_demo.entity.mysql.DbConnection;
import java.util.List;
import java.util.Map;
/**
*
*/
public interface DatabaseSchemaService {
/**
*
* @param connection
* @return
*/
String getDatabaseSchema(DbConnection connection);
/**
*
* @param connection
* @param tableName
* @return
*/
Map<String, Object> getTableSchema(DbConnection connection, String tableName);
/**
*
* @param connection
* @return
*/
List<String> getTableNames(DbConnection connection);
}

@ -1,5 +1,6 @@
package com.example.springboot_demo.service;
import com.example.springboot_demo.entity.mysql.DbConnection;
import java.util.Map;
/**
@ -14,6 +15,26 @@ public interface LlmService {
* @return SQLMap
*/
Map<String, Object> generateQuery(String prompt, String modelName, String databaseName);
/**
* SQL
* @param prompt
* @param modelName
* @param databaseName
* @param schemaInfo
* @return SQLMap
*/
Map<String, Object> generateQueryWithSchema(String prompt, String modelName, String databaseName, String schemaInfo);
/**
* SQL
* @param prompt
* @param modelConfigId ID
* @param databaseName
* @param dbConnection
* @return SQLMap
*/
Map<String, Object> generateQueryWithConnection(String prompt, String modelConfigId, String databaseName, DbConnection dbConnection);
}

@ -0,0 +1,264 @@
package com.example.springboot_demo.service.impl;
import com.example.springboot_demo.entity.mysql.DbConnection;
import com.example.springboot_demo.service.DatabaseSchemaService;
import org.springframework.stereotype.Service;
import java.sql.*;
import java.util.*;
/**
*
*/
@Service
public class DatabaseSchemaServiceImpl implements DatabaseSchemaService {
/**
* ID JDBC URL
*/
private String getJdbcUrlPrefix(Integer dbTypeId) {
if (dbTypeId == null) {
throw new IllegalArgumentException("数据库类型ID不能为空");
}
switch (dbTypeId) {
case 1: // MySQL
return "jdbc:mysql://";
case 2: // PostgreSQL
return "jdbc:postgresql://";
case 3: // Oracle
return "jdbc:oracle:thin:@";
case 4: // SQL Server
return "jdbc:sqlserver://";
default:
throw new IllegalArgumentException("不支持的数据库类型ID: " + dbTypeId);
}
}
/**
* ID JDBC
*/
private String getDriverClassName(Integer dbTypeId) {
if (dbTypeId == null) {
throw new IllegalArgumentException("数据库类型ID不能为空");
}
switch (dbTypeId) {
case 1: // MySQL
return "com.mysql.cj.jdbc.Driver";
case 2: // PostgreSQL
return "org.postgresql.Driver";
case 3: // Oracle
return "oracle.jdbc.driver.OracleDriver";
case 4: // SQL Server
return "com.microsoft.sqlserver.jdbc.SQLServerDriver";
default:
throw new IllegalArgumentException("不支持的数据库类型ID: " + dbTypeId);
}
}
/**
* JDBC URL
*/
private String buildJdbcUrl(DbConnection connection) {
String prefix = getJdbcUrlPrefix(connection.getDbTypeId());
String url = connection.getUrl();
// 对于 MySQL添加额外参数
if (connection.getDbTypeId() != null && connection.getDbTypeId() == 1) {
if (url.contains("?")) {
return prefix + url + "&useUnicode=true&characterEncoding=utf8&useSSL=false&serverTimezone=Asia/Shanghai";
} else {
return prefix + url + "?useUnicode=true&characterEncoding=utf8&useSSL=false&serverTimezone=Asia/Shanghai";
}
}
return prefix + url;
}
@Override
public String getDatabaseSchema(DbConnection connection) {
try {
List<String> tableNames = getTableNames(connection);
StringBuilder schema = new StringBuilder();
schema.append("数据库表结构信息:\n\n");
for (String tableName : tableNames) {
Map<String, Object> tableSchema = getTableSchema(connection, tableName);
schema.append(formatTableSchema(tableName, tableSchema));
schema.append("\n");
}
return schema.toString();
} catch (Exception e) {
System.err.println("获取数据库表结构失败: " + e.getMessage());
e.printStackTrace();
return "无法获取表结构信息";
}
}
@Override
public List<String> getTableNames(DbConnection connection) {
List<String> tableNames = new ArrayList<>();
try {
Class.forName(getDriverClassName(connection.getDbTypeId()));
String jdbcUrl = buildJdbcUrl(connection);
try (Connection conn = DriverManager.getConnection(
jdbcUrl,
connection.getUsername(),
connection.getPassword()
)) {
DatabaseMetaData metaData = conn.getMetaData();
// 从URL中提取数据库名
String databaseName = extractDatabaseName(connection.getUrl());
// 获取所有表
try (ResultSet tables = metaData.getTables(databaseName, null, "%", new String[]{"TABLE"})) {
while (tables.next()) {
String tableName = tables.getString("TABLE_NAME");
tableNames.add(tableName);
}
}
}
} catch (Exception e) {
System.err.println("获取表名列表失败: " + e.getMessage());
e.printStackTrace();
}
return tableNames;
}
@Override
public Map<String, Object> getTableSchema(DbConnection connection, String tableName) {
Map<String, Object> tableInfo = new LinkedHashMap<>();
List<Map<String, String>> columns = new ArrayList<>();
try {
Class.forName(getDriverClassName(connection.getDbTypeId()));
String jdbcUrl = buildJdbcUrl(connection);
try (Connection conn = DriverManager.getConnection(
jdbcUrl,
connection.getUsername(),
connection.getPassword()
)) {
DatabaseMetaData metaData = conn.getMetaData();
String databaseName = extractDatabaseName(connection.getUrl());
// 获取表注释MySQL特有
if (connection.getDbTypeId() == 1) {
try (Statement stmt = conn.createStatement();
ResultSet rs = stmt.executeQuery(
"SELECT TABLE_COMMENT FROM information_schema.TABLES " +
"WHERE TABLE_SCHEMA = '" + databaseName + "' AND TABLE_NAME = '" + tableName + "'"
)) {
if (rs.next()) {
tableInfo.put("comment", rs.getString("TABLE_COMMENT"));
}
} catch (Exception e) {
// 忽略获取注释失败的错误
}
}
// 获取列信息
try (ResultSet columnsRs = metaData.getColumns(databaseName, null, tableName, "%")) {
while (columnsRs.next()) {
Map<String, String> columnInfo = new LinkedHashMap<>();
columnInfo.put("name", columnsRs.getString("COLUMN_NAME"));
columnInfo.put("type", columnsRs.getString("TYPE_NAME"));
columnInfo.put("size", columnsRs.getString("COLUMN_SIZE"));
columnInfo.put("nullable", columnsRs.getInt("NULLABLE") == DatabaseMetaData.columnNullable ? "YES" : "NO");
String remarks = columnsRs.getString("REMARKS");
if (remarks != null && !remarks.isEmpty()) {
columnInfo.put("comment", remarks);
}
columns.add(columnInfo);
}
}
// 获取主键信息
List<String> primaryKeys = new ArrayList<>();
try (ResultSet pkRs = metaData.getPrimaryKeys(databaseName, null, tableName)) {
while (pkRs.next()) {
primaryKeys.add(pkRs.getString("COLUMN_NAME"));
}
}
tableInfo.put("tableName", tableName);
tableInfo.put("columns", columns);
tableInfo.put("primaryKeys", primaryKeys);
}
} catch (Exception e) {
System.err.println("获取表结构失败: " + tableName + " - " + e.getMessage());
e.printStackTrace();
}
return tableInfo;
}
/**
* URL
*/
private String extractDatabaseName(String url) {
// 格式host:port/database 或 host:port/database?params
if (url.contains("/")) {
String[] parts = url.split("/");
if (parts.length >= 2) {
String dbPart = parts[parts.length - 1];
// 去除参数
if (dbPart.contains("?")) {
dbPart = dbPart.split("\\?")[0];
}
return dbPart;
}
}
return null;
}
/**
*
*/
private String formatTableSchema(String tableName, Map<String, Object> tableSchema) {
StringBuilder sb = new StringBuilder();
sb.append("表名: ").append(tableName);
if (tableSchema.containsKey("comment")) {
sb.append(" (").append(tableSchema.get("comment")).append(")");
}
sb.append("\n");
@SuppressWarnings("unchecked")
List<Map<String, String>> columns = (List<Map<String, String>>) tableSchema.get("columns");
if (columns != null && !columns.isEmpty()) {
sb.append("字段:\n");
for (Map<String, String> column : columns) {
sb.append(" - ").append(column.get("name"))
.append(" (").append(column.get("type"))
.append(", 可空: ").append(column.get("nullable"))
.append(")");
if (column.containsKey("comment")) {
sb.append(" // ").append(column.get("comment"));
}
sb.append("\n");
}
}
@SuppressWarnings("unchecked")
List<String> primaryKeys = (List<String>) tableSchema.get("primaryKeys");
if (primaryKeys != null && !primaryKeys.isEmpty()) {
sb.append("主键: ").append(String.join(", ", primaryKeys)).append("\n");
}
return sb.toString();
}
}

@ -5,6 +5,8 @@ import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
import com.example.springboot_demo.entity.mysql.DbConnection;
import com.example.springboot_demo.mapper.DbConnectionMapper;
import com.example.springboot_demo.service.DbConnectionService;
import com.example.springboot_demo.utils.DynamicDatabaseExecutor;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import java.util.List;
@ -12,6 +14,9 @@ import java.util.List;
@Service
public class DbConnectionServiceImpl extends ServiceImpl<DbConnectionMapper, DbConnection> implements DbConnectionService {
@Autowired
private DynamicDatabaseExecutor databaseExecutor;
@Override
public List<DbConnection> listByCreateUserId(Long createUserId) {
LambdaQueryWrapper<DbConnection> wrapper = new LambdaQueryWrapper<>();
@ -22,14 +27,22 @@ public class DbConnectionServiceImpl extends ServiceImpl<DbConnectionMapper, DbC
@Override
public boolean testConnection(Long id) {
// TODO: 实现真实的数据库连接测试逻辑
// 暂时返回 Mock 结果
DbConnection connection = getById(id);
if (connection == null) {
System.err.println("数据库连接不存在: ID=" + id);
return false;
}
try {
// 使用动态数据库执行器测试连接
boolean result = databaseExecutor.testConnection(connection);
System.out.println("数据库连接测试结果: " + (result ? "成功" : "失败") + " - " + connection.getName());
return result;
} catch (Exception e) {
System.err.println("数据库连接测试异常: " + e.getMessage());
e.printStackTrace();
return false;
}
// 这里应该根据 db_type_id 和连接信息实际测试连接
return true;
}
}

@ -14,7 +14,9 @@ import org.springframework.stereotype.Service;
import com.alibaba.fastjson2.JSON;
import com.alibaba.fastjson2.JSONObject;
import com.example.springboot_demo.entity.mysql.DbConnection;
import com.example.springboot_demo.entity.mysql.LlmConfig;
import com.example.springboot_demo.service.DatabaseSchemaService;
import com.example.springboot_demo.service.LlmConfigService;
import com.example.springboot_demo.service.LlmService;
@ -27,6 +29,9 @@ public class LlmServiceImpl implements LlmService {
@Autowired
private LlmConfigService llmConfigService;
@Autowired
private DatabaseSchemaService databaseSchemaService;
private final HttpClient httpClient = HttpClient.newBuilder()
.connectTimeout(Duration.ofSeconds(30))
@ -34,6 +39,11 @@ public class LlmServiceImpl implements LlmService {
@Override
public Map<String, Object> generateQuery(String prompt, String modelConfigId, String databaseName) {
return generateQueryWithSchema(prompt, modelConfigId, databaseName, null);
}
@Override
public Map<String, Object> generateQueryWithSchema(String prompt, String modelConfigId, String databaseName, String schemaInfo) {
try {
// 根据配置ID从数据库获取模型配置
LlmConfig config = llmConfigService.getById(Long.valueOf(modelConfigId));
@ -45,7 +55,7 @@ public class LlmServiceImpl implements LlmService {
}
// 统一调用大模型API
return callLlmApi(prompt, config, databaseName);
return callLlmApi(prompt, config, databaseName, schemaInfo);
} catch (NumberFormatException e) {
throw new RuntimeException("无效的模型配置ID: " + modelConfigId);
} catch (Exception e) {
@ -57,7 +67,7 @@ public class LlmServiceImpl implements LlmService {
* API
* OpenAI Chat Completions API
*/
private Map<String, Object> callLlmApi(String prompt, LlmConfig config, String databaseName) throws Exception {
private Map<String, Object> callLlmApi(String prompt, LlmConfig config, String databaseName, String schemaInfo) throws Exception {
String apiKey = config.getApiKey().trim();
String url = config.getApiUrl().trim();
String modelName = config.getVersion().trim();
@ -74,7 +84,7 @@ public class LlmServiceImpl implements LlmService {
requestBody.put("model", modelName);
requestBody.put("messages", Arrays.asList(Map.of(
"role", "user",
"content", generatePrompt(prompt, databaseName)
"content", generatePrompt(prompt, databaseName, schemaInfo)
)));
requestBody.put("response_format", Map.of("type", "json_object"));
requestBody.put("temperature", 0.0);
@ -123,37 +133,49 @@ public class LlmServiceImpl implements LlmService {
}
/**
* Prompt
* Prompt
*/
private String generatePrompt(String prompt, String databaseName) {
return String.format(
"你是数据查询助手需将用户请求转换为指定JSON格式。\n" +
"连接的数据库为\"%s\"仅生成该数据库的SQL。\n" +
"响应必须是单个有效的JSON对象不包含任何额外文本或格式如```json。\n\n" +
"用户请求:\"%s\"\n\n" +
"规则:\n" +
"- 数据查询可SQL回答success=true生成SQL、表格数据和图表数据\n" +
"- 非数据查询success=false表格数据用[\"Message\"]和[\"抱歉,仅支持数据查询\"]\n\n" +
"返回JSON格式\n" +
"{\n" +
" \"success\": true/false,\n" +
" \"sqlQuery\": \"SQL语句\",\n" +
" \"tableData\": {\n" +
" \"headers\": [\"列1\", \"列2\"],\n" +
" \"rows\": [[\"值1\", \"值2\"]]\n" +
" },\n" +
" \"chartData\": {\n" +
" \"type\": \"bar/line/pie\",\n" +
" \"labels\": [\"标签1\"],\n" +
" \"datasets\": [{\n" +
" \"label\": \"数据标签\",\n" +
" \"data\": [1, 2, 3],\n" +
" \"backgroundColor\": \"rgba(22, 93, 255, 0.6)\"\n" +
" }]\n" +
" }\n" +
"}",
databaseName, prompt
);
private String generatePrompt(String prompt, String databaseName, String schemaInfo) {
StringBuilder promptBuilder = new StringBuilder();
promptBuilder.append("你是数据查询助手需将用户请求转换为指定JSON格式。\n");
promptBuilder.append("连接的数据库为\"").append(databaseName).append("\"仅生成该数据库的SQL。\n");
promptBuilder.append("响应必须是单个有效的JSON对象不包含任何额外文本或格式如```json。\n\n");
// 如果有表结构信息添加到Prompt中
if (schemaInfo != null && !schemaInfo.isEmpty()) {
promptBuilder.append("=== 数据库表结构信息 ===\n");
promptBuilder.append(schemaInfo);
promptBuilder.append("\n请根据上述真实的表结构生成SQL确保使用正确的表名和列名。\n");
promptBuilder.append("注意:必须使用实际存在的列名,不要猜测或假设列名。\n\n");
}
promptBuilder.append("用户请求:\"").append(prompt).append("\"\n\n");
promptBuilder.append("规则:\n");
promptBuilder.append("- 数据查询可SQL回答success=true生成SQL、表格数据和图表数据\n");
promptBuilder.append("- 非数据查询success=false表格数据用[\"Message\"]和[\"抱歉,仅支持数据查询\"]\n");
promptBuilder.append("- 必须使用上述表结构中实际存在的列名\n");
promptBuilder.append("- SQL语句必须符合MySQL语法\n\n");
promptBuilder.append("返回JSON格式\n");
promptBuilder.append("{\n");
promptBuilder.append(" \"success\": true/false,\n");
promptBuilder.append(" \"sqlQuery\": \"SQL语句\",\n");
promptBuilder.append(" \"tableData\": {\n");
promptBuilder.append(" \"headers\": [\"列1\", \"列2\"],\n");
promptBuilder.append(" \"rows\": [[\"值1\", \"值2\"]]\n");
promptBuilder.append(" },\n");
promptBuilder.append(" \"chartData\": {\n");
promptBuilder.append(" \"type\": \"bar/line/pie\",\n");
promptBuilder.append(" \"labels\": [\"标签1\"],\n");
promptBuilder.append(" \"datasets\": [{\n");
promptBuilder.append(" \"label\": \"数据标签\",\n");
promptBuilder.append(" \"data\": [1, 2, 3],\n");
promptBuilder.append(" \"backgroundColor\": \"rgba(22, 93, 255, 0.6)\"\n");
promptBuilder.append(" }]\n");
promptBuilder.append(" }\n");
promptBuilder.append("}");
return promptBuilder.toString();
}
/**
@ -172,4 +194,99 @@ public class LlmServiceImpl implements LlmService {
throw new RuntimeException("解析模型响应失败: " + e.getMessage(), e);
}
}
@Override
public Map<String, Object> generateQueryWithConnection(String prompt, String modelConfigId, String databaseName, DbConnection dbConnection) {
try {
// 获取数据库表结构
System.out.println("✓ 开始获取数据库表结构信息...");
String schemaInfo = databaseSchemaService.getDatabaseSchema(dbConnection);
System.out.println("✓ 已获取数据库表结构信息");
// 根据配置ID从数据库获取模型配置
LlmConfig config = llmConfigService.getById(Long.valueOf(modelConfigId));
if (config == null) {
throw new RuntimeException("模型配置不存在ID: " + modelConfigId);
}
if (config.getIsDisabled() == 1) {
throw new RuntimeException("该模型配置已被禁用");
}
// 使用包含表结构的 prompt 调用大模型
return callLlmApiWithSchema(prompt, config, databaseName, schemaInfo);
} catch (NumberFormatException e) {
throw new RuntimeException("无效的模型配置ID: " + modelConfigId);
} catch (Exception e) {
throw new RuntimeException("模型调用失败: " + e.getMessage(), e);
}
}
/**
* API
*/
private Map<String, Object> callLlmApiWithSchema(String prompt, LlmConfig config, String databaseName, String schemaInfo) throws Exception {
String apiKey = config.getApiKey().trim();
String url = config.getApiUrl().trim();
String modelName = config.getVersion().trim();
// 打印调试信息
System.out.println("=== LLM API 调用信息(含表结构) ===");
System.out.println("配置名称: " + config.getName() + " (ID: " + config.getId() + ")");
System.out.println("API URL: " + url);
System.out.println("模型名称: " + modelName);
// 构建包含表结构的 prompt
String enhancedPrompt = generatePrompt(prompt, databaseName, schemaInfo);
// 构建请求体OpenAI Chat Completions API 格式)
JSONObject requestBody = new JSONObject();
requestBody.put("model", modelName);
requestBody.put("messages", Arrays.asList(Map.of(
"role", "user",
"content", enhancedPrompt
)));
requestBody.put("response_format", Map.of("type", "json_object"));
requestBody.put("temperature", 0.0);
System.out.println("发送请求到大模型...");
// 发送HTTP请求
HttpRequest request = HttpRequest.newBuilder()
.uri(URI.create(url))
.header("Content-Type", "application/json")
.header("Authorization", "Bearer " + apiKey)
.POST(HttpRequest.BodyPublishers.ofString(requestBody.toJSONString()))
.timeout(Duration.ofSeconds(config.getTimeout() != null ? config.getTimeout() / 1000 : 60))
.build();
HttpResponse<String> response = httpClient.send(request, HttpResponse.BodyHandlers.ofString());
System.out.println("响应状态码: " + response.statusCode());
System.out.println("=========================");
if (response.statusCode() != 200) {
throw new RuntimeException("API调用失败: " + response.statusCode() + ", 响应: " + response.body());
}
JSONObject jsonResponse = JSON.parseObject(response.body());
// 解析响应OpenAI格式
if (!jsonResponse.containsKey("choices") || jsonResponse.getJSONArray("choices").isEmpty()) {
throw new RuntimeException("API响应格式错误缺少choices字段");
}
JSONObject choice = jsonResponse.getJSONArray("choices").getJSONObject(0);
if (!choice.containsKey("message")) {
throw new RuntimeException("API响应格式错误缺少message字段");
}
String content = choice.getJSONObject("message").getString("content");
if (content == null || content.isEmpty()) {
throw new RuntimeException("API返回内容为空");
}
// 清理可能的markdown代码块标记
String cleanedContent = content.replaceAll("^```json\\n|```$", "").trim();
return parseJsonResponse(cleanedContent);
}
}

@ -4,9 +4,12 @@ import com.alibaba.fastjson2.JSONArray;
import com.alibaba.fastjson2.JSONObject;
import com.example.springboot_demo.dto.QueryRequestDTO;
import com.example.springboot_demo.entity.mongodb.DialogRecord;
import com.example.springboot_demo.entity.mysql.DbConnection;
import com.example.springboot_demo.repository.DialogRecordRepository;
import com.example.springboot_demo.service.DbConnectionService;
import com.example.springboot_demo.service.LlmService;
import com.example.springboot_demo.service.QueryService;
import com.example.springboot_demo.utils.DynamicDatabaseExecutor;
import com.example.springboot_demo.vo.QueryResponseVO;
import com.example.springboot_demo.vo.TableDataVO;
import com.example.springboot_demo.vo.ChartDataVO;
@ -28,10 +31,34 @@ public class QueryServiceImpl implements QueryService {
@Autowired
private LlmService llmService;
@Autowired
private DbConnectionService dbConnectionService;
@Autowired
private DynamicDatabaseExecutor databaseExecutor;
@Override
public QueryResponseVO executeQuery(QueryRequestDTO request, Long userId) {
long startTime = System.currentTimeMillis();
// 验证必要参数
if (request.getDbConnectionId() == null) {
throw new RuntimeException("数据库连接ID不能为空");
}
// 获取数据库连接配置
DbConnection dbConnection = dbConnectionService.getById(request.getDbConnectionId());
if (dbConnection == null) {
throw new RuntimeException("数据库连接不存在ID: " + request.getDbConnectionId());
}
if ("disabled".equals(dbConnection.getStatus())) {
throw new RuntimeException("该数据库连接已被禁用");
}
System.out.println("=== 查询执行开始 ===");
System.out.println("数据库连接: " + dbConnection.getName() + " (ID: " + dbConnection.getId() + ")");
System.out.println("用户提示: " + request.getUserPrompt());
// 生成或获取对话ID
String conversationId = request.getConversationId();
if (conversationId == null || conversationId.isEmpty()) {
@ -61,13 +88,69 @@ public class QueryServiceImpl implements QueryService {
}
}
// 调用大模型API生成SQL和结果
Map<String, Object> llmResult = llmService.generateQuery(
// 调用大模型API生成SQL传递数据库连接以获取表结构
String databaseName = request.getDatabase() != null ? request.getDatabase() : dbConnection.getName();
Map<String, Object> llmResult = llmService.generateQueryWithConnection(
request.getUserPrompt(),
request.getModel(),
request.getDatabase()
databaseName,
dbConnection // 传递数据库连接让LLM服务自动获取表结构
);
String generatedSql = (String) llmResult.getOrDefault("sqlQuery", "");
System.out.println("✓ 大模型生成 SQL: " + generatedSql);
// 执行SQL获取真实数据
TableDataVO realTableData = null;
boolean executionSuccess = false;
try {
if (generatedSql != null && !generatedSql.trim().isEmpty()) {
// 使用动态数据库执行器执行SQL
Map<String, Object> queryResult = databaseExecutor.executeQuery(dbConnection, generatedSql);
// 转换为TableDataVO
realTableData = new TableDataVO();
// 安全地处理 headers
Object headersObj = queryResult.get("headers");
List<String> headers = new ArrayList<>();
if (headersObj instanceof List) {
for (Object header : (List<?>) headersObj) {
headers.add(header != null ? header.toString() : "");
}
}
realTableData.setHeaders(headers);
// 安全地处理 rows
Object rowsObj = queryResult.get("rows");
List<List<String>> stringRows = new ArrayList<>();
if (rowsObj instanceof List) {
for (Object rowObj : (List<?>) rowsObj) {
List<String> rowList = new ArrayList<>();
if (rowObj instanceof List) {
for (Object cell : (List<?>) rowObj) {
rowList.add(cell != null ? cell.toString() : "");
}
}
stringRows.add(rowList);
}
}
realTableData.setRows(stringRows);
executionSuccess = true;
System.out.println("✓ SQL执行成功返回 " + stringRows.size() + " 行数据");
}
} catch (Exception e) {
System.err.println("✗ SQL执行失败: " + e.getMessage());
// SQL执行失败时使用错误信息创建表格数据
realTableData = new TableDataVO();
realTableData.setHeaders(Arrays.asList("错误信息"));
realTableData.setRows(Arrays.asList(
Arrays.asList("SQL执行失败: " + e.getMessage())
));
}
// 计算执行时间
long endTime = System.currentTimeMillis();
String executionTime = String.format("%.1f秒", (endTime - startTime) / 1000.0);
@ -76,30 +159,136 @@ public class QueryServiceImpl implements QueryService {
QueryResponseVO response = new QueryResponseVO();
response.setId("query_" + UUID.randomUUID().toString().substring(0, 8));
response.setUserPrompt(request.getUserPrompt());
response.setSqlQuery((String) llmResult.getOrDefault("sqlQuery", ""));
response.setSqlQuery(generatedSql);
response.setConversationId(conversationId);
response.setQueryTime(LocalDateTime.now().format(DateTimeFormatter.ISO_DATE_TIME));
response.setExecutionTime(executionTime);
response.setDatabase(request.getDatabase());
response.setDatabase(databaseName);
response.setModel(request.getModel());
// 解析表格数据
Object tableDataObj = llmResult.get("tableData");
if (tableDataObj != null) {
TableDataVO tableData = parseTableData(tableDataObj);
response.setTableData(tableData);
// 使用真实的表格数据(如果执行成功)或大模型生成的数据
if (executionSuccess && realTableData != null) {
response.setTableData(realTableData);
} else if (realTableData != null) {
// SQL执行失败显示错误信息
response.setTableData(realTableData);
} else {
// 降级:使用大模型生成的模拟数据
Object tableDataObj = llmResult.get("tableData");
if (tableDataObj != null) {
TableDataVO tableData = parseTableData(tableDataObj);
response.setTableData(tableData);
}
}
// 尝试从真实数据生成图表(如果大模型生成的图表为空,或我们想优先使用真实数据)
// 目前策略如果SQL执行成功且有数据优先尝试从真实数据生成图表
ChartDataVO realChartData = null;
if (executionSuccess && realTableData != null &&
realTableData.getRows() != null && !realTableData.getRows().isEmpty() &&
realTableData.getHeaders() != null && realTableData.getHeaders().size() >= 2) {
try {
realChartData = generateChartFromData(realTableData);
} catch (Exception e) {
System.err.println("从真实数据生成图表失败: " + e.getMessage());
}
}
// 解析图表数据
Object chartDataObj = llmResult.get("chartData");
if (chartDataObj != null) {
ChartDataVO chartData = parseChartData(chartDataObj);
response.setChartData(chartData);
// 解析图表数据(优先使用从真实数据生成的图表)
if (realChartData != null) {
response.setChartData(realChartData);
} else {
// 降级:使用大模型生成的图表(可能是空的或基于幻觉的)
Object chartDataObj = llmResult.get("chartData");
if (chartDataObj != null) {
ChartDataVO chartData = parseChartData(chartDataObj);
// 只有当大模型返回的图表有数据时才使用
if (chartData != null && chartData.getDatasets() != null && !chartData.getDatasets().isEmpty()) {
response.setChartData(chartData);
}
}
}
System.out.println("=== 查询执行完成 ===");
return response;
}
/**
*
*
* 1. Label (X)
* 2. Data (Y)
* 3. 使 (bar)
*/
private ChartDataVO generateChartFromData(TableDataVO tableData) {
List<String> headers = tableData.getHeaders();
List<List<String>> rows = tableData.getRows();
if (headers.size() < 2 || rows.isEmpty()) {
return null;
}
// 1. 获取 Labels (第一列)
List<String> labels = rows.stream()
.map(row -> row.size() > 0 ? row.get(0) : "")
.collect(Collectors.toList());
// 2. 寻找数值列
int valueColumnIndex = -1;
for (int i = 1; i < headers.size(); i++) {
boolean isNumeric = true;
for (List<String> row : rows) {
if (row.size() <= i) continue;
String val = row.get(i);
if (val == null || val.isEmpty()) continue; // 跳过空值
try {
// 尝试解析为数字(处理可能的逗号等格式,如 "1,234"
Double.parseDouble(val.replace(",", ""));
} catch (NumberFormatException e) {
isNumeric = false;
break;
}
}
if (isNumeric) {
valueColumnIndex = i;
break; // 找到第一个数值列即可
}
}
if (valueColumnIndex == -1) {
return null; // 没有找到数值列,无法生成图表
}
// 3. 构建图表数据
ChartDataVO chartData = new ChartDataVO();
chartData.setType("bar"); // 默认柱状图
chartData.setLabels(labels);
DatasetVO dataset = new DatasetVO();
dataset.setLabel(headers.get(valueColumnIndex)); // 使用列名作为图例
dataset.setBackgroundColor("rgba(54, 162, 235, 0.6)"); // 默认蓝色
final int colIdx = valueColumnIndex;
List<Number> data = rows.stream()
.map(row -> {
if (row.size() <= colIdx) return 0.0;
String val = row.get(colIdx);
if (val == null || val.isEmpty()) return 0.0;
try {
return Double.parseDouble(val.replace(",", ""));
} catch (NumberFormatException e) {
return 0.0;
}
})
.collect(Collectors.toList());
dataset.setData(data);
chartData.setDatasets(Collections.singletonList(dataset));
System.out.println("✓ 已根据真实数据自动生成图表");
return chartData;
}
/**
*
*/

@ -0,0 +1,414 @@
package com.example.springboot_demo.utils;
import com.example.springboot_demo.entity.mysql.DbConnection;
import org.springframework.stereotype.Component;
import java.sql.*;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
/**
*
* DbConnection SQL
*/
@Component
public class DynamicDatabaseExecutor {
/**
*
*/
public static class TableSchema {
private String tableName;
private String tableComment;
private List<ColumnInfo> columns;
public TableSchema() {
this.columns = new ArrayList<>();
}
public String getTableName() { return tableName; }
public void setTableName(String tableName) { this.tableName = tableName; }
public String getTableComment() { return tableComment; }
public void setTableComment(String tableComment) { this.tableComment = tableComment; }
public List<ColumnInfo> getColumns() { return columns; }
public void setColumns(List<ColumnInfo> columns) { this.columns = columns; }
@Override
public String toString() {
StringBuilder sb = new StringBuilder();
sb.append("表名: ").append(tableName);
if (tableComment != null && !tableComment.isEmpty()) {
sb.append(" (").append(tableComment).append(")");
}
sb.append("\n列信息:\n");
for (ColumnInfo col : columns) {
sb.append(" - ").append(col.columnName)
.append(" (").append(col.dataType).append(")");
if (col.columnComment != null && !col.columnComment.isEmpty()) {
sb.append(" - ").append(col.columnComment);
}
if (col.isPrimaryKey) {
sb.append(" [主键]");
}
sb.append("\n");
}
return sb.toString();
}
}
/**
*
*/
public static class ColumnInfo {
private String columnName;
private String dataType;
private String columnComment;
private boolean isPrimaryKey;
private boolean isNullable;
public String getColumnName() { return columnName; }
public void setColumnName(String columnName) { this.columnName = columnName; }
public String getDataType() { return dataType; }
public void setDataType(String dataType) { this.dataType = dataType; }
public String getColumnComment() { return columnComment; }
public void setColumnComment(String columnComment) { this.columnComment = columnComment; }
public boolean isPrimaryKey() { return isPrimaryKey; }
public void setPrimaryKey(boolean primaryKey) { isPrimaryKey = primaryKey; }
public boolean isNullable() { return isNullable; }
public void setNullable(boolean nullable) { isNullable = nullable; }
}
/**
* ID JDBC URL
*/
private String getJdbcUrlPrefix(Integer dbTypeId) {
if (dbTypeId == null) {
throw new IllegalArgumentException("数据库类型ID不能为空");
}
switch (dbTypeId) {
case 1: // MySQL
return "jdbc:mysql://";
case 2: // PostgreSQL
return "jdbc:postgresql://";
case 3: // Oracle
return "jdbc:oracle:thin:@";
case 4: // SQL Server
return "jdbc:sqlserver://";
default:
throw new IllegalArgumentException("不支持的数据库类型ID: " + dbTypeId);
}
}
/**
* ID JDBC
*/
private String getDriverClassName(Integer dbTypeId) {
if (dbTypeId == null) {
throw new IllegalArgumentException("数据库类型ID不能为空");
}
switch (dbTypeId) {
case 1: // MySQL
return "com.mysql.cj.jdbc.Driver";
case 2: // PostgreSQL
return "org.postgresql.Driver";
case 3: // Oracle
return "oracle.jdbc.driver.OracleDriver";
case 4: // SQL Server
return "com.microsoft.sqlserver.jdbc.SQLServerDriver";
default:
throw new IllegalArgumentException("不支持的数据库类型ID: " + dbTypeId);
}
}
/**
* JDBC URL
*/
private String buildJdbcUrl(DbConnection connection) {
String prefix = getJdbcUrlPrefix(connection.getDbTypeId());
String url = connection.getUrl(); // 格式: host:port 或 host:port/database
// 对于 MySQL添加额外参数
if (connection.getDbTypeId() != null && connection.getDbTypeId() == 1L) {
if (url.contains("?")) {
return prefix + url + "&useUnicode=true&characterEncoding=utf8&useSSL=false&serverTimezone=Asia/Shanghai";
} else {
return prefix + url + "?useUnicode=true&characterEncoding=utf8&useSSL=false&serverTimezone=Asia/Shanghai";
}
}
return prefix + url;
}
/**
*
* @param connection
* @return
*/
public boolean testConnection(DbConnection connection) {
try {
// 加载驱动
Class.forName(getDriverClassName(connection.getDbTypeId()));
// 构建 JDBC URL
String jdbcUrl = buildJdbcUrl(connection);
System.out.println("测试连接: " + jdbcUrl);
System.out.println("用户名: " + connection.getUsername());
// 尝试连接
try (Connection conn = DriverManager.getConnection(
jdbcUrl,
connection.getUsername(),
connection.getPassword()
)) {
return conn != null && !conn.isClosed();
}
} catch (Exception e) {
System.err.println("数据库连接测试失败: " + e.getMessage());
e.printStackTrace();
return false;
}
}
/**
* SQL
* @param connection
* @param sql SQL
* @return headers rows Map
*/
public Map<String, Object> executeQuery(DbConnection connection, String sql) {
Map<String, Object> result = new HashMap<>();
List<String> headers = new ArrayList<>();
List<List<Object>> rows = new ArrayList<>();
try {
// 加载驱动
Class.forName(getDriverClassName(connection.getDbTypeId()));
// 构建 JDBC URL
String jdbcUrl = buildJdbcUrl(connection);
System.out.println("执行查询 URL: " + jdbcUrl);
System.out.println("执行 SQL: " + sql);
// 连接数据库并执行查询
try (Connection conn = DriverManager.getConnection(
jdbcUrl,
connection.getUsername(),
connection.getPassword()
)) {
try (Statement stmt = conn.createStatement();
ResultSet rs = stmt.executeQuery(sql)) {
// 获取列信息
ResultSetMetaData metaData = rs.getMetaData();
int columnCount = metaData.getColumnCount();
// 提取列名
for (int i = 1; i <= columnCount; i++) {
headers.add(metaData.getColumnLabel(i));
}
// 提取数据行
while (rs.next()) {
List<Object> row = new ArrayList<>();
for (int i = 1; i <= columnCount; i++) {
Object value = rs.getObject(i);
row.add(value != null ? value.toString() : "");
}
rows.add(row);
}
}
}
result.put("headers", headers);
result.put("rows", rows);
System.out.println("查询成功,返回 " + rows.size() + " 行数据");
} catch (Exception e) {
System.err.println("SQL执行失败: " + e.getMessage());
e.printStackTrace();
throw new RuntimeException("SQL执行失败: " + e.getMessage(), e);
}
return result;
}
/**
* SQL INSERT, UPDATE, DELETE
* @param connection
* @param sql SQL
* @return
*/
public int executeUpdate(DbConnection connection, String sql) {
try {
// 加载驱动
Class.forName(getDriverClassName(connection.getDbTypeId()));
// 构建 JDBC URL
String jdbcUrl = buildJdbcUrl(connection);
System.out.println("执行更新 URL: " + jdbcUrl);
System.out.println("执行 SQL: " + sql);
// 连接数据库并执行更新
try (Connection conn = DriverManager.getConnection(
jdbcUrl,
connection.getUsername(),
connection.getPassword()
)) {
try (Statement stmt = conn.createStatement()) {
int affectedRows = stmt.executeUpdate(sql);
System.out.println("更新成功,影响 " + affectedRows + " 行");
return affectedRows;
}
}
} catch (Exception e) {
System.err.println("SQL更新失败: " + e.getMessage());
e.printStackTrace();
throw new RuntimeException("SQL更新失败: " + e.getMessage(), e);
}
}
/**
*
* @param connection
* @return
*/
public List<TableSchema> getDatabaseSchema(DbConnection connection) {
List<TableSchema> schemas = new ArrayList<>();
try {
// 加载驱动
Class.forName(getDriverClassName(connection.getDbTypeId()));
// 构建 JDBC URL
String jdbcUrl = buildJdbcUrl(connection);
System.out.println("获取数据库结构: " + jdbcUrl);
// 连接数据库
try (Connection conn = DriverManager.getConnection(
jdbcUrl,
connection.getUsername(),
connection.getPassword()
)) {
DatabaseMetaData metaData = conn.getMetaData();
// 从URL中提取数据库名
String databaseName = extractDatabaseName(connection.getUrl());
// 获取所有表
try (ResultSet tables = metaData.getTables(databaseName, null, "%", new String[]{"TABLE"})) {
while (tables.next()) {
TableSchema tableSchema = new TableSchema();
String tableName = tables.getString("TABLE_NAME");
String tableComment = tables.getString("REMARKS");
tableSchema.setTableName(tableName);
tableSchema.setTableComment(tableComment);
// 获取该表的所有列
List<ColumnInfo> columns = new ArrayList<>();
try (ResultSet columns_rs = metaData.getColumns(databaseName, null, tableName, "%")) {
while (columns_rs.next()) {
ColumnInfo columnInfo = new ColumnInfo();
columnInfo.setColumnName(columns_rs.getString("COLUMN_NAME"));
columnInfo.setDataType(columns_rs.getString("TYPE_NAME"));
columnInfo.setColumnComment(columns_rs.getString("REMARKS"));
columnInfo.setNullable("YES".equalsIgnoreCase(columns_rs.getString("IS_NULLABLE")));
columns.add(columnInfo);
}
}
// 获取主键信息
try (ResultSet primaryKeys = metaData.getPrimaryKeys(databaseName, null, tableName)) {
while (primaryKeys.next()) {
String pkColumnName = primaryKeys.getString("COLUMN_NAME");
for (ColumnInfo col : columns) {
if (col.getColumnName().equals(pkColumnName)) {
col.setPrimaryKey(true);
break;
}
}
}
}
tableSchema.setColumns(columns);
schemas.add(tableSchema);
}
}
System.out.println("成功获取 " + schemas.size() + " 个表的结构信息");
}
} catch (Exception e) {
System.err.println("获取数据库结构失败: " + e.getMessage());
e.printStackTrace();
throw new RuntimeException("获取数据库结构失败: " + e.getMessage(), e);
}
return schemas;
}
/**
* URL
* @param url "localhost:3306/database_name"
* @return
*/
private String extractDatabaseName(String url) {
if (url.contains("/")) {
String[] parts = url.split("/");
if (parts.length > 1) {
// 移除可能的查询参数
String dbName = parts[parts.length - 1];
if (dbName.contains("?")) {
dbName = dbName.substring(0, dbName.indexOf("?"));
}
return dbName;
}
}
return null;
}
/**
*
* @param schemas
* @return
*/
public String formatSchemaForLLM(List<TableSchema> schemas) {
StringBuilder sb = new StringBuilder();
sb.append("数据库表结构信息:\n\n");
for (TableSchema schema : schemas) {
sb.append("表名: ").append(schema.getTableName());
if (schema.getTableComment() != null && !schema.getTableComment().isEmpty()) {
sb.append(" (").append(schema.getTableComment()).append(")");
}
sb.append("\n");
for (ColumnInfo col : schema.getColumns()) {
sb.append(" - ").append(col.getColumnName())
.append(": ").append(col.getDataType());
if (col.getColumnComment() != null && !col.getColumnComment().isEmpty()) {
sb.append(" // ").append(col.getColumnComment());
}
if (col.isPrimaryKey()) {
sb.append(" [主键]");
}
sb.append("\n");
}
sb.append("\n");
}
return sb.toString();
}
}

@ -22,9 +22,9 @@ spring:
data:
mongodb:
# 如果 MongoDB 没有启用认证,使用下面的配置(无用户名密码)
uri: mongodb://127.0.0.1:27017/natural_language_query_system
# uri: mongodb://127.0.0.1:27017/natural_language_query_system
# 如果 MongoDB 启用了认证,使用下面的配置(需要替换为正确的用户名和密码)
# uri: mongodb://admin:admin123456@127.0.0.1:27017/natural_language_query_system?authSource=admin
uri: mongodb://admin:admin123456@127.0.0.1:27017/natural_language_query_system?authSource=admin
# Redis 配置(可选)
redis:

Loading…
Cancel
Save