|
|
|
|
@ -14,7 +14,9 @@ import org.springframework.stereotype.Service;
|
|
|
|
|
|
|
|
|
|
import com.alibaba.fastjson2.JSON;
|
|
|
|
|
import com.alibaba.fastjson2.JSONObject;
|
|
|
|
|
import com.example.springboot_demo.entity.mysql.DbConnection;
|
|
|
|
|
import com.example.springboot_demo.entity.mysql.LlmConfig;
|
|
|
|
|
import com.example.springboot_demo.service.DatabaseSchemaService;
|
|
|
|
|
import com.example.springboot_demo.service.LlmConfigService;
|
|
|
|
|
import com.example.springboot_demo.service.LlmService;
|
|
|
|
|
|
|
|
|
|
@ -27,6 +29,9 @@ public class LlmServiceImpl implements LlmService {
|
|
|
|
|
|
|
|
|
|
@Autowired
|
|
|
|
|
private LlmConfigService llmConfigService;
|
|
|
|
|
|
|
|
|
|
@Autowired
|
|
|
|
|
private DatabaseSchemaService databaseSchemaService;
|
|
|
|
|
|
|
|
|
|
private final HttpClient httpClient = HttpClient.newBuilder()
|
|
|
|
|
.connectTimeout(Duration.ofSeconds(30))
|
|
|
|
|
@ -34,6 +39,11 @@ public class LlmServiceImpl implements LlmService {
|
|
|
|
|
|
|
|
|
|
@Override
|
|
|
|
|
public Map<String, Object> generateQuery(String prompt, String modelConfigId, String databaseName) {
|
|
|
|
|
return generateQueryWithSchema(prompt, modelConfigId, databaseName, null);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@Override
|
|
|
|
|
public Map<String, Object> generateQueryWithSchema(String prompt, String modelConfigId, String databaseName, String schemaInfo) {
|
|
|
|
|
try {
|
|
|
|
|
// 根据配置ID从数据库获取模型配置
|
|
|
|
|
LlmConfig config = llmConfigService.getById(Long.valueOf(modelConfigId));
|
|
|
|
|
@ -45,7 +55,7 @@ public class LlmServiceImpl implements LlmService {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// 统一调用大模型API
|
|
|
|
|
return callLlmApi(prompt, config, databaseName);
|
|
|
|
|
return callLlmApi(prompt, config, databaseName, schemaInfo);
|
|
|
|
|
} catch (NumberFormatException e) {
|
|
|
|
|
throw new RuntimeException("无效的模型配置ID: " + modelConfigId);
|
|
|
|
|
} catch (Exception e) {
|
|
|
|
|
@ -57,7 +67,7 @@ public class LlmServiceImpl implements LlmService {
|
|
|
|
|
* 统一调用大模型API
|
|
|
|
|
* 支持所有兼容 OpenAI Chat Completions API 的模型
|
|
|
|
|
*/
|
|
|
|
|
private Map<String, Object> callLlmApi(String prompt, LlmConfig config, String databaseName) throws Exception {
|
|
|
|
|
private Map<String, Object> callLlmApi(String prompt, LlmConfig config, String databaseName, String schemaInfo) throws Exception {
|
|
|
|
|
String apiKey = config.getApiKey().trim();
|
|
|
|
|
String url = config.getApiUrl().trim();
|
|
|
|
|
String modelName = config.getVersion().trim();
|
|
|
|
|
@ -74,7 +84,7 @@ public class LlmServiceImpl implements LlmService {
|
|
|
|
|
requestBody.put("model", modelName);
|
|
|
|
|
requestBody.put("messages", Arrays.asList(Map.of(
|
|
|
|
|
"role", "user",
|
|
|
|
|
"content", generatePrompt(prompt, databaseName)
|
|
|
|
|
"content", generatePrompt(prompt, databaseName, schemaInfo)
|
|
|
|
|
)));
|
|
|
|
|
requestBody.put("response_format", Map.of("type", "json_object"));
|
|
|
|
|
requestBody.put("temperature", 0.0);
|
|
|
|
|
@ -123,37 +133,49 @@ public class LlmServiceImpl implements LlmService {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* 生成统一的Prompt
|
|
|
|
|
* 生成统一的Prompt(包含数据库结构信息)
|
|
|
|
|
*/
|
|
|
|
|
private String generatePrompt(String prompt, String databaseName) {
|
|
|
|
|
return String.format(
|
|
|
|
|
"你是数据查询助手,需将用户请求转换为指定JSON格式。\n" +
|
|
|
|
|
"连接的数据库为\"%s\",仅生成该数据库的SQL。\n" +
|
|
|
|
|
"响应必须是单个有效的JSON对象,不包含任何额外文本或格式(如```json)。\n\n" +
|
|
|
|
|
"用户请求:\"%s\"\n\n" +
|
|
|
|
|
"规则:\n" +
|
|
|
|
|
"- 数据查询(可SQL回答):success=true,生成SQL、表格数据和图表数据\n" +
|
|
|
|
|
"- 非数据查询:success=false,表格数据用[\"Message\"]和[\"抱歉,仅支持数据查询\"]\n\n" +
|
|
|
|
|
"返回JSON格式:\n" +
|
|
|
|
|
"{\n" +
|
|
|
|
|
" \"success\": true/false,\n" +
|
|
|
|
|
" \"sqlQuery\": \"SQL语句\",\n" +
|
|
|
|
|
" \"tableData\": {\n" +
|
|
|
|
|
" \"headers\": [\"列1\", \"列2\"],\n" +
|
|
|
|
|
" \"rows\": [[\"值1\", \"值2\"]]\n" +
|
|
|
|
|
" },\n" +
|
|
|
|
|
" \"chartData\": {\n" +
|
|
|
|
|
" \"type\": \"bar/line/pie\",\n" +
|
|
|
|
|
" \"labels\": [\"标签1\"],\n" +
|
|
|
|
|
" \"datasets\": [{\n" +
|
|
|
|
|
" \"label\": \"数据标签\",\n" +
|
|
|
|
|
" \"data\": [1, 2, 3],\n" +
|
|
|
|
|
" \"backgroundColor\": \"rgba(22, 93, 255, 0.6)\"\n" +
|
|
|
|
|
" }]\n" +
|
|
|
|
|
" }\n" +
|
|
|
|
|
"}",
|
|
|
|
|
databaseName, prompt
|
|
|
|
|
);
|
|
|
|
|
private String generatePrompt(String prompt, String databaseName, String schemaInfo) {
|
|
|
|
|
StringBuilder promptBuilder = new StringBuilder();
|
|
|
|
|
|
|
|
|
|
promptBuilder.append("你是数据查询助手,需将用户请求转换为指定JSON格式。\n");
|
|
|
|
|
promptBuilder.append("连接的数据库为\"").append(databaseName).append("\",仅生成该数据库的SQL。\n");
|
|
|
|
|
promptBuilder.append("响应必须是单个有效的JSON对象,不包含任何额外文本或格式(如```json)。\n\n");
|
|
|
|
|
|
|
|
|
|
// 如果有表结构信息,添加到Prompt中
|
|
|
|
|
if (schemaInfo != null && !schemaInfo.isEmpty()) {
|
|
|
|
|
promptBuilder.append("=== 数据库表结构信息 ===\n");
|
|
|
|
|
promptBuilder.append(schemaInfo);
|
|
|
|
|
promptBuilder.append("\n请根据上述真实的表结构生成SQL,确保使用正确的表名和列名。\n");
|
|
|
|
|
promptBuilder.append("注意:必须使用实际存在的列名,不要猜测或假设列名。\n\n");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
promptBuilder.append("用户请求:\"").append(prompt).append("\"\n\n");
|
|
|
|
|
promptBuilder.append("规则:\n");
|
|
|
|
|
promptBuilder.append("- 数据查询(可SQL回答):success=true,生成SQL、表格数据和图表数据\n");
|
|
|
|
|
promptBuilder.append("- 非数据查询:success=false,表格数据用[\"Message\"]和[\"抱歉,仅支持数据查询\"]\n");
|
|
|
|
|
promptBuilder.append("- 必须使用上述表结构中实际存在的列名\n");
|
|
|
|
|
promptBuilder.append("- SQL语句必须符合MySQL语法\n\n");
|
|
|
|
|
promptBuilder.append("返回JSON格式:\n");
|
|
|
|
|
promptBuilder.append("{\n");
|
|
|
|
|
promptBuilder.append(" \"success\": true/false,\n");
|
|
|
|
|
promptBuilder.append(" \"sqlQuery\": \"SQL语句\",\n");
|
|
|
|
|
promptBuilder.append(" \"tableData\": {\n");
|
|
|
|
|
promptBuilder.append(" \"headers\": [\"列1\", \"列2\"],\n");
|
|
|
|
|
promptBuilder.append(" \"rows\": [[\"值1\", \"值2\"]]\n");
|
|
|
|
|
promptBuilder.append(" },\n");
|
|
|
|
|
promptBuilder.append(" \"chartData\": {\n");
|
|
|
|
|
promptBuilder.append(" \"type\": \"bar/line/pie\",\n");
|
|
|
|
|
promptBuilder.append(" \"labels\": [\"标签1\"],\n");
|
|
|
|
|
promptBuilder.append(" \"datasets\": [{\n");
|
|
|
|
|
promptBuilder.append(" \"label\": \"数据标签\",\n");
|
|
|
|
|
promptBuilder.append(" \"data\": [1, 2, 3],\n");
|
|
|
|
|
promptBuilder.append(" \"backgroundColor\": \"rgba(22, 93, 255, 0.6)\"\n");
|
|
|
|
|
promptBuilder.append(" }]\n");
|
|
|
|
|
promptBuilder.append(" }\n");
|
|
|
|
|
promptBuilder.append("}");
|
|
|
|
|
|
|
|
|
|
return promptBuilder.toString();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
@ -172,4 +194,99 @@ public class LlmServiceImpl implements LlmService {
|
|
|
|
|
throw new RuntimeException("解析模型响应失败: " + e.getMessage(), e);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@Override
|
|
|
|
|
public Map<String, Object> generateQueryWithConnection(String prompt, String modelConfigId, String databaseName, DbConnection dbConnection) {
|
|
|
|
|
try {
|
|
|
|
|
// 获取数据库表结构
|
|
|
|
|
System.out.println("✓ 开始获取数据库表结构信息...");
|
|
|
|
|
String schemaInfo = databaseSchemaService.getDatabaseSchema(dbConnection);
|
|
|
|
|
System.out.println("✓ 已获取数据库表结构信息");
|
|
|
|
|
|
|
|
|
|
// 根据配置ID从数据库获取模型配置
|
|
|
|
|
LlmConfig config = llmConfigService.getById(Long.valueOf(modelConfigId));
|
|
|
|
|
if (config == null) {
|
|
|
|
|
throw new RuntimeException("模型配置不存在,ID: " + modelConfigId);
|
|
|
|
|
}
|
|
|
|
|
if (config.getIsDisabled() == 1) {
|
|
|
|
|
throw new RuntimeException("该模型配置已被禁用");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// 使用包含表结构的 prompt 调用大模型
|
|
|
|
|
return callLlmApiWithSchema(prompt, config, databaseName, schemaInfo);
|
|
|
|
|
} catch (NumberFormatException e) {
|
|
|
|
|
throw new RuntimeException("无效的模型配置ID: " + modelConfigId);
|
|
|
|
|
} catch (Exception e) {
|
|
|
|
|
throw new RuntimeException("模型调用失败: " + e.getMessage(), e);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* 调用大模型API(包含表结构信息)
|
|
|
|
|
*/
|
|
|
|
|
private Map<String, Object> callLlmApiWithSchema(String prompt, LlmConfig config, String databaseName, String schemaInfo) throws Exception {
|
|
|
|
|
String apiKey = config.getApiKey().trim();
|
|
|
|
|
String url = config.getApiUrl().trim();
|
|
|
|
|
String modelName = config.getVersion().trim();
|
|
|
|
|
|
|
|
|
|
// 打印调试信息
|
|
|
|
|
System.out.println("=== LLM API 调用信息(含表结构) ===");
|
|
|
|
|
System.out.println("配置名称: " + config.getName() + " (ID: " + config.getId() + ")");
|
|
|
|
|
System.out.println("API URL: " + url);
|
|
|
|
|
System.out.println("模型名称: " + modelName);
|
|
|
|
|
|
|
|
|
|
// 构建包含表结构的 prompt
|
|
|
|
|
String enhancedPrompt = generatePrompt(prompt, databaseName, schemaInfo);
|
|
|
|
|
|
|
|
|
|
// 构建请求体(OpenAI Chat Completions API 格式)
|
|
|
|
|
JSONObject requestBody = new JSONObject();
|
|
|
|
|
requestBody.put("model", modelName);
|
|
|
|
|
requestBody.put("messages", Arrays.asList(Map.of(
|
|
|
|
|
"role", "user",
|
|
|
|
|
"content", enhancedPrompt
|
|
|
|
|
)));
|
|
|
|
|
requestBody.put("response_format", Map.of("type", "json_object"));
|
|
|
|
|
requestBody.put("temperature", 0.0);
|
|
|
|
|
|
|
|
|
|
System.out.println("发送请求到大模型...");
|
|
|
|
|
|
|
|
|
|
// 发送HTTP请求
|
|
|
|
|
HttpRequest request = HttpRequest.newBuilder()
|
|
|
|
|
.uri(URI.create(url))
|
|
|
|
|
.header("Content-Type", "application/json")
|
|
|
|
|
.header("Authorization", "Bearer " + apiKey)
|
|
|
|
|
.POST(HttpRequest.BodyPublishers.ofString(requestBody.toJSONString()))
|
|
|
|
|
.timeout(Duration.ofSeconds(config.getTimeout() != null ? config.getTimeout() / 1000 : 60))
|
|
|
|
|
.build();
|
|
|
|
|
|
|
|
|
|
HttpResponse<String> response = httpClient.send(request, HttpResponse.BodyHandlers.ofString());
|
|
|
|
|
|
|
|
|
|
System.out.println("响应状态码: " + response.statusCode());
|
|
|
|
|
System.out.println("=========================");
|
|
|
|
|
|
|
|
|
|
if (response.statusCode() != 200) {
|
|
|
|
|
throw new RuntimeException("API调用失败: " + response.statusCode() + ", 响应: " + response.body());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
JSONObject jsonResponse = JSON.parseObject(response.body());
|
|
|
|
|
|
|
|
|
|
// 解析响应(OpenAI格式)
|
|
|
|
|
if (!jsonResponse.containsKey("choices") || jsonResponse.getJSONArray("choices").isEmpty()) {
|
|
|
|
|
throw new RuntimeException("API响应格式错误:缺少choices字段");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
JSONObject choice = jsonResponse.getJSONArray("choices").getJSONObject(0);
|
|
|
|
|
if (!choice.containsKey("message")) {
|
|
|
|
|
throw new RuntimeException("API响应格式错误:缺少message字段");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
String content = choice.getJSONObject("message").getString("content");
|
|
|
|
|
if (content == null || content.isEmpty()) {
|
|
|
|
|
throw new RuntimeException("API返回内容为空");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// 清理可能的markdown代码块标记
|
|
|
|
|
String cleanedContent = content.replaceAll("^```json\\n|```$", "").trim();
|
|
|
|
|
return parseJsonResponse(cleanedContent);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|