You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
git-test/src/main/java/net/micode/notes/tool/AIService.java

410 lines
20 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

package net.micode.notes.tool;
import android.content.Context;
import android.graphics.Bitmap;
import android.util.Base64;
import android.util.Log;
import org.json.JSONException;
import org.json.JSONObject;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.HttpURLConnection;
import java.net.URL;
import java.nio.charset.StandardCharsets;
/**
* AIService - AI服务类
* <p>
* 用于处理与AI相关的服务调用如豆包API
* </p>
*/
public class AIService {
private static final String TAG = "AIService";
private static final String DOUBAO_API_URL = "https://ark.cn-beijing.volces.com/api/v3/responses";
private static final String API_KEY = "ee5fb4c7-ea14-4481-ac23-4b0e82907850";
private static final String SECRET_ACCESS_KEY = "";
/**
* 提取图片内容
* @param bitmap 图片bitmap
* @param callback 回调接口
*/
public static void extractImageContent(final Bitmap bitmap, final ExtractImageContentCallback callback) {
new Thread(new Runnable() {
@Override
public void run() {
try {
Log.d(TAG, "Starting image content extraction...");
// 检查bitmap
if (bitmap == null) {
Log.e(TAG, "Bitmap is null");
callback.onFailure("Bitmap is null");
return;
}
Log.d(TAG, "Bitmap width: " + bitmap.getWidth() + ", height: " + bitmap.getHeight());
// 将bitmap转换为Base64
Log.d(TAG, "Converting bitmap to base64...");
String base64Image = bitmapToBase64(bitmap);
if (base64Image == null) {
Log.e(TAG, "Failed to convert bitmap to base64");
callback.onFailure("Failed to convert bitmap to base64");
return;
}
Log.d(TAG, "Base64 conversion successful, length: " + base64Image.length());
// 构建请求体
Log.d(TAG, "Building request body...");
JSONObject requestBody = new JSONObject();
requestBody.put("model", "ep-20260127214554-frsrr"); // 新的推理接入点ID
// 创建input数组
org.json.JSONArray input = new org.json.JSONArray();
// 创建user input
JSONObject userInput = new JSONObject();
userInput.put("role", "user");
// 创建content数组
org.json.JSONArray contentArray = new org.json.JSONArray();
// 添加图片部分
JSONObject imageContent = new JSONObject();
imageContent.put("type", "input_image");
imageContent.put("image_url", "data:image/jpeg;base64," + base64Image);
contentArray.put(imageContent);
// 添加文本部分
JSONObject textContent = new JSONObject();
textContent.put("type", "input_text");
textContent.put("text", "请提取这张图片中的所有文字和结构化数据,包括表格、列表等信息,清晰准确地格式化提取的内容。");
contentArray.put(textContent);
userInput.put("content", contentArray);
input.put(userInput);
requestBody.put("input", input);
// 发送请求
String requestBodyString = requestBody.toString();
Log.d(TAG, "Request body length: " + requestBodyString.length());
Log.d(TAG, "Request body (first 1000 chars): " + (requestBodyString.length() > 1000 ? requestBodyString.substring(0, 1000) + "..." : requestBodyString));
Log.d(TAG, "Sending POST request to: " + DOUBAO_API_URL);
String response = sendPostRequest(DOUBAO_API_URL, requestBodyString);
if (response == null) {
Log.e(TAG, "Failed to get response from Doubao API");
callback.onFailure("Failed to get response from Doubao API");
return;
}
Log.d(TAG, "Got response from Doubao API, length: " + response.length());
Log.d(TAG, "Response content: " + response);
// 解析响应
Log.d(TAG, "Parsing response...");
JSONObject responseJson = new JSONObject(response);
// 检查响应格式
if (responseJson.has("output")) {
Log.d(TAG, "Response has output field");
try {
// 尝试作为数组处理(新格式)
org.json.JSONArray outputArray = responseJson.getJSONArray("output");
Log.d(TAG, "Output is an array, length: " + outputArray.length());
// 遍历数组找到包含文本的message
String extractedText = "";
for (int i = 0; i < outputArray.length(); i++) {
JSONObject item = outputArray.getJSONObject(i);
Log.d(TAG, "Output item " + i + ": " + item.toString());
// 检查是否是message类型
if (item.has("type") && "message".equals(item.getString("type"))) {
Log.d(TAG, "Found message item");
if (item.has("content")) {
org.json.JSONArray messageContentArray = item.getJSONArray("content");
for (int j = 0; j < messageContentArray.length(); j++) {
JSONObject contentItem = messageContentArray.getJSONObject(j);
if (contentItem.has("type") && "output_text".equals(contentItem.getString("type"))) {
extractedText = contentItem.getString("text");
Log.d(TAG, "Got text from response: " + extractedText);
callback.onSuccess(extractedText);
return;
}
}
}
}
// 检查是否有role字段为assistant
else if (item.has("role") && "assistant".equals(item.getString("role"))) {
Log.d(TAG, "Found assistant item");
if (item.has("content")) {
org.json.JSONArray assistantContentArray = item.getJSONArray("content");
for (int j = 0; j < assistantContentArray.length(); j++) {
JSONObject contentItem = assistantContentArray.getJSONObject(j);
if (contentItem.has("type") && "output_text".equals(contentItem.getString("type"))) {
extractedText = contentItem.getString("text");
Log.d(TAG, "Got text from assistant response: " + extractedText);
callback.onSuccess(extractedText);
return;
}
}
}
}
}
// 如果没有找到文本,尝试其他方式
if (extractedText.isEmpty()) {
Log.e(TAG, "No text found in output array");
callback.onFailure("No text found in output array");
}
} catch (JSONException e) {
// 如果不是数组,尝试作为对象处理(旧格式)
Log.d(TAG, "Output is not an array, trying as object: " + e.getMessage());
try {
JSONObject outputObj = responseJson.getJSONObject("output");
if (outputObj.has("text")) {
String content = outputObj.getString("text");
Log.d(TAG, "Got text from response object: " + content);
callback.onSuccess(content);
} else if (outputObj.has("content")) {
String content = outputObj.getString("content");
Log.d(TAG, "Got content from response object: " + content);
callback.onSuccess(content);
} else {
Log.e(TAG, "No text or content in response object: " + outputObj.toString());
callback.onFailure("No text or content in response");
}
} catch (JSONException ex) {
Log.e(TAG, "Error parsing output: " + ex.getMessage());
callback.onFailure("Error parsing output: " + ex.getMessage());
}
}
} else if (responseJson.has("choices")) {
// 兼容旧格式
Log.d(TAG, "Response has choices field");
org.json.JSONArray choices = responseJson.getJSONArray("choices");
if (choices.length() > 0) {
JSONObject choice = choices.getJSONObject(0);
if (choice.has("message")) {
JSONObject message = choice.getJSONObject("message");
if (message.has("content")) {
String content = message.getString("content");
Log.d(TAG, "Got content from choices: " + content);
callback.onSuccess(content);
} else {
Log.e(TAG, "No content in message: " + message.toString());
callback.onFailure("No content in message");
}
} else {
Log.e(TAG, "No message in choice: " + choice.toString());
callback.onFailure("No message in choice");
}
} else {
Log.e(TAG, "No choices in response");
callback.onFailure("No choices in response");
}
} else if (responseJson.has("error")) {
// 处理错误响应
Log.e(TAG, "API returned error: " + responseJson.toString());
JSONObject error = responseJson.getJSONObject("error");
String errorMessage = error.getString("message");
callback.onFailure("API error: " + errorMessage);
} else {
Log.e(TAG, "Unexpected response format: " + responseJson.toString());
callback.onFailure("Unexpected response format: " + responseJson.toString());
}
} catch (JSONException e) {
Log.e(TAG, "JSONException: " + e.getMessage());
e.printStackTrace();
callback.onFailure("JSON error: " + e.getMessage());
} catch (Exception e) {
Log.e(TAG, "Exception: " + e.getMessage());
e.printStackTrace();
callback.onFailure("Error: " + e.getMessage());
}
}
}).start();
}
/**
* 将Bitmap转换为Base64字符串
* @param bitmap 图片bitmap
* @return Base64字符串
*/
private static String bitmapToBase64(Bitmap bitmap) {
ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
bitmap.compress(Bitmap.CompressFormat.JPEG, 80, byteArrayOutputStream);
byte[] byteArray = byteArrayOutputStream.toByteArray();
try {
byteArrayOutputStream.close();
} catch (IOException e) {
e.printStackTrace();
}
return Base64.encodeToString(byteArray, Base64.NO_WRAP);
}
/**
* 发送POST请求
* @param urlString URL字符串
* @param requestBody 请求体
* @return 响应字符串
*/
private static String sendPostRequest(String urlString, String requestBody) {
try {
Log.d(TAG, "Sending POST request to: " + urlString);
Log.d(TAG, "Request body length: " + requestBody.length());
Log.d(TAG, "Request body (first 500 chars): " + (requestBody.length() > 500 ? requestBody.substring(0, 500) + "..." : requestBody));
URL url = new URL(urlString);
HttpURLConnection connection = (HttpURLConnection) url.openConnection();
connection.setRequestMethod("POST");
// 设置请求头
connection.setRequestProperty("Content-Type", "application/json");
connection.setRequestProperty("Authorization", "Bearer " + API_KEY);
connection.setRequestProperty("X-TT-LOGID", System.currentTimeMillis() + "");
connection.setRequestProperty("User-Agent", "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36");
connection.setDoOutput(true);
connection.setConnectTimeout(30000); // 设置连接超时为30秒
connection.setReadTimeout(30000); // 设置读取超时为30秒
// 写入请求体
Log.d(TAG, "Writing request body...");
OutputStream outputStream = connection.getOutputStream();
outputStream.write(requestBody.getBytes(StandardCharsets.UTF_8));
outputStream.flush();
outputStream.close();
Log.d(TAG, "Request body written successfully");
// 读取响应
Log.d(TAG, "Reading response...");
int responseCode = connection.getResponseCode();
Log.d(TAG, "HTTP response code: " + responseCode);
// 读取所有响应头
Log.d(TAG, "Response headers:");
java.util.Map<String, java.util.List<String>> headers = connection.getHeaderFields();
for (String key : headers.keySet()) {
if (key != null) {
Log.d(TAG, key + ": " + headers.get(key));
}
}
if (responseCode == HttpURLConnection.HTTP_OK) {
Log.d(TAG, "HTTP OK, reading response body...");
InputStream inputStream = connection.getInputStream();
byte[] buffer = new byte[1024];
int bytesRead;
ByteArrayOutputStream responseStream = new ByteArrayOutputStream();
while ((bytesRead = inputStream.read(buffer)) != -1) {
responseStream.write(buffer, 0, bytesRead);
}
String responseString = responseStream.toString(StandardCharsets.UTF_8.name());
responseStream.close();
inputStream.close();
connection.disconnect();
Log.d(TAG, "API response length: " + responseString.length());
Log.d(TAG, "API response (first 500 chars): " + (responseString.length() > 500 ? responseString.substring(0, 500) + "..." : responseString));
return responseString;
} else {
// 读取错误响应
Log.e(TAG, "HTTP error, reading error response...");
InputStream errorStream = connection.getErrorStream();
if (errorStream != null) {
byte[] buffer = new byte[1024];
int bytesRead;
ByteArrayOutputStream errorResponseStream = new ByteArrayOutputStream();
while ((bytesRead = errorStream.read(buffer)) != -1) {
errorResponseStream.write(buffer, 0, bytesRead);
}
String errorResponse = errorResponseStream.toString(StandardCharsets.UTF_8.name());
errorResponseStream.close();
errorStream.close();
Log.e(TAG, "HTTP error: " + responseCode + ", Error response: " + errorResponse);
} else {
Log.e(TAG, "HTTP error code: " + responseCode + ", No error stream available");
}
connection.disconnect();
return null;
}
} catch (java.net.SocketTimeoutException e) {
Log.e(TAG, "Socket timeout error: " + e.getMessage());
e.printStackTrace();
return null;
} catch (java.net.ConnectException e) {
Log.e(TAG, "Connection error: " + e.getMessage());
e.printStackTrace();
return null;
} catch (java.io.IOException e) {
Log.e(TAG, "IO error: " + e.getMessage());
e.printStackTrace();
return null;
} catch (Exception e) {
Log.e(TAG, "Error sending POST request: " + e.getMessage());
e.printStackTrace();
return null;
}
}
/**
* 提取图片内容回调接口
*/
public interface ExtractImageContentCallback {
void onSuccess(String extractedContent);
void onFailure(String errorMessage);
}
/**
* 测试API连接
*/
public static void testApiConnection(final ExtractImageContentCallback callback) {
new Thread(new Runnable() {
@Override
public void run() {
try {
// 构建测试请求体
JSONObject requestBody = new JSONObject();
requestBody.put("model", "ep-20260127214554-frsrr");
org.json.JSONArray input = new org.json.JSONArray();
JSONObject userInput = new JSONObject();
userInput.put("role", "user");
org.json.JSONArray contentArray = new org.json.JSONArray();
JSONObject textContent = new JSONObject();
textContent.put("type", "input_text");
textContent.put("text", "Hello, test connection");
contentArray.put(textContent);
userInput.put("content", contentArray);
input.put(userInput);
requestBody.put("input", input);
Log.d(TAG, "Testing API connection...");
String response = sendPostRequest(DOUBAO_API_URL, requestBody.toString());
if (response != null) {
Log.d(TAG, "API connection test successful: " + response);
callback.onSuccess("API connection test successful");
} else {
Log.e(TAG, "API connection test failed");
callback.onFailure("API connection test failed");
}
} catch (Exception e) {
Log.e(TAG, "Error testing API connection: " + e.getMessage());
callback.onFailure("Error testing API connection: " + e.getMessage());
}
}
}).start();
}
}