|
|
|
package ai_model_cli
|
|
|
|
|
|
|
|
import (
	"context"
	"fmt"
	"os"
	"path/filepath"

	"goskeleton/app/global/consts"
	"goskeleton/app/global/variable"

	"github.com/baidubce/bce-qianfan-sdk/go/qianfan"
	"github.com/gin-gonic/gin"
)
|
|
|
|
|
|
|
|
func RequestStyle(c *gin.Context) (interface{}, error) {
|
|
|
|
|
|
|
|
// userMsg := c.PostForm("user_input")
|
|
|
|
userMsg := c.GetString(consts.ValidatorPrefix + "user_input")
|
|
|
|
|
|
|
|
qianfan.GetConfig().AccessKey = variable.ConfigYml.GetString("BaiduCE.QianFanAccessKey")
|
|
|
|
qianfan.GetConfig().SecretKey = variable.ConfigYml.GetString("BaiduCE.QianFanSecretKey")
|
|
|
|
|
|
|
|
chat := qianfan.NewChatCompletion(
|
|
|
|
qianfan.WithModel("ERNIE-4.0-8K"),
|
|
|
|
)
|
|
|
|
|
|
|
|
chatHistory := []qianfan.ChatCompletionMessage{}
|
|
|
|
|
|
|
|
// 读取prompt文件
|
|
|
|
systemMsgPath := variable.ConfigYml.GetString("BaiduCE.StyleGeneratePromptPath")
|
|
|
|
// 读取文件内容
|
|
|
|
prompt, err := os.ReadFile(variable.BasePath + systemMsgPath)
|
|
|
|
if err != nil || len(prompt) == 0 {
|
|
|
|
variable.ZapLog.Error(fmt.Sprintf("读取提示词文件失败: %v", err))
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
|
|
|
|
// add user history to chat history
|
|
|
|
userHistory, exist := c.Get(consts.ValidatorPrefix + "chat_history")
|
|
|
|
if exist && userHistory != nil {
|
|
|
|
// check if userHistory is of type []struct{Role string;Content string}
|
|
|
|
historySlice, ok := userHistory.([]interface{})
|
|
|
|
if !ok || len(historySlice)%2 != 0 {
|
|
|
|
variable.ZapLog.Error(fmt.Sprintf("用户历史对话格式错误: %v", userHistory))
|
|
|
|
return nil, fmt.Errorf("用户历史对话格式错误")
|
|
|
|
}
|
|
|
|
|
|
|
|
// convert userHistory to []qianfan.ChatCompletionMessage
|
|
|
|
var chatHistoryConverted []qianfan.ChatCompletionMessage
|
|
|
|
for _, item := range historySlice {
|
|
|
|
if itemMap, ok := item.(map[string]interface{}); ok {
|
|
|
|
role, roleOk := itemMap["role"].(string)
|
|
|
|
content, contentOk := itemMap["content"].(string)
|
|
|
|
if roleOk && contentOk {
|
|
|
|
chatHistoryConverted = append(chatHistoryConverted, qianfan.ChatCompletionMessage{
|
|
|
|
Role: role,
|
|
|
|
Content: content,
|
|
|
|
})
|
|
|
|
} else {
|
|
|
|
variable.ZapLog.Error(fmt.Sprintf("用户历史对话格式错误: %v\nrole 或 content 类型断言失败", userHistory))
|
|
|
|
return nil, fmt.Errorf("用户历史对话格式错误")
|
|
|
|
}
|
|
|
|
} else {
|
|
|
|
variable.ZapLog.Error(fmt.Sprintf("用户历史对话格式错误: %v\n无法将 item 转换为 map[string]interface{}", userHistory))
|
|
|
|
return nil, fmt.Errorf("用户历史对话格式错误")
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
if len(chatHistoryConverted) > 0 && len(chatHistoryConverted)%2 == 0 {
|
|
|
|
chatHistory = append(chatHistory, chatHistoryConverted...)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// add user input to chat history
|
|
|
|
chatHistory = append(chatHistory, qianfan.ChatCompletionUserMessage(userMsg))
|
|
|
|
|
|
|
|
response, err := chat.Do(context.TODO(), &qianfan.ChatCompletionRequest{System: string(prompt), Messages: chatHistory})
|
|
|
|
if err != nil {
|
|
|
|
variable.ZapLog.Error(fmt.Sprintf("对话失败: %v", err))
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
|
|
|
|
return response.Result, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func RequestStyleStream(c *gin.Context) error {
|
|
|
|
userMsg := c.GetString(consts.ValidatorPrefix + "user_input")
|
|
|
|
|
|
|
|
qianfan.GetConfig().AccessKey = variable.ConfigYml.GetString("BaiduCE.QianFanAccessKey")
|
|
|
|
qianfan.GetConfig().SecretKey = variable.ConfigYml.GetString("BaiduCE.QianFanSecretKey")
|
|
|
|
|
|
|
|
chat := qianfan.NewChatCompletion(
|
|
|
|
qianfan.WithModel("ERNIE-4.0-8K"),
|
|
|
|
)
|
|
|
|
|
|
|
|
chatHistory := []qianfan.ChatCompletionMessage{}
|
|
|
|
|
|
|
|
systemMsgPath := variable.ConfigYml.GetString("BaiduCE.StyleGeneratePromptPath")
|
|
|
|
prompt, err := os.ReadFile(variable.BasePath + systemMsgPath)
|
|
|
|
if err != nil || len(prompt) == 0 {
|
|
|
|
variable.ZapLog.Error(fmt.Sprintf("读取提示词文件失败: %v", err))
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
userHistory, exist := c.Get(consts.ValidatorPrefix + "chat_history")
|
|
|
|
if exist && userHistory != nil {
|
|
|
|
historySlice, ok := userHistory.([]interface{})
|
|
|
|
if !ok || len(historySlice)%2 != 0 {
|
|
|
|
variable.ZapLog.Error(fmt.Sprintf("用户历史对话格式错误: %v", userHistory))
|
|
|
|
return fmt.Errorf("用户历史对话格式错误")
|
|
|
|
}
|
|
|
|
|
|
|
|
var chatHistoryConverted []qianfan.ChatCompletionMessage
|
|
|
|
for _, item := range historySlice {
|
|
|
|
if itemMap, ok := item.(map[string]interface{}); ok {
|
|
|
|
role, roleOk := itemMap["role"].(string)
|
|
|
|
content, contentOk := itemMap["content"].(string)
|
|
|
|
if roleOk && contentOk {
|
|
|
|
chatHistoryConverted = append(chatHistoryConverted, qianfan.ChatCompletionMessage{
|
|
|
|
Role: role,
|
|
|
|
Content: content,
|
|
|
|
})
|
|
|
|
} else {
|
|
|
|
variable.ZapLog.Error(fmt.Sprintf("用户历史对话格式错误: %v\nrole 或 content 类型断言失败", userHistory))
|
|
|
|
return fmt.Errorf("用户历史对话格式错误")
|
|
|
|
}
|
|
|
|
} else {
|
|
|
|
variable.ZapLog.Error(fmt.Sprintf("用户历史对话格式错误: %v\n无法将 item 转换为 map[string]interface{}", userHistory))
|
|
|
|
return fmt.Errorf("用户历史对话格式错误")
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
if len(chatHistoryConverted) > 0 && len(chatHistoryConverted)%2 == 0 {
|
|
|
|
chatHistory = append(chatHistory, chatHistoryConverted...)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
chatHistory = append(chatHistory, qianfan.ChatCompletionUserMessage(userMsg))
|
|
|
|
|
|
|
|
stream, err := chat.Stream(context.TODO(), &qianfan.ChatCompletionRequest{System: string(prompt), Messages: chatHistory})
|
|
|
|
if err != nil {
|
|
|
|
variable.ZapLog.Error(fmt.Sprintf("对话失败: %v", err))
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
defer stream.Close()
|
|
|
|
|
|
|
|
c.Writer.Flush()
|
|
|
|
defer c.Writer.Flush()
|
|
|
|
for {
|
|
|
|
response, err := stream.Recv()
|
|
|
|
if response.IsEnd {
|
|
|
|
break // 流结束,退出循环
|
|
|
|
}
|
|
|
|
if err != nil {
|
|
|
|
variable.ZapLog.Error(fmt.Sprintf("接收流失败: %v", err))
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
// 将结果写入到响应体
|
|
|
|
if _,err:=fmt.Fprintf(c.Writer,"%s",response.Result);err!=nil{
|
|
|
|
variable.ZapLog.Error(fmt.Sprintf("写入流失败: %v", err))
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
// 立即刷新缓冲区,以确保数据立即发送到客户端
|
|
|
|
c.Writer.Flush()
|
|
|
|
}
|
|
|
|
return nil // 正常结束,返回 nil
|
|
|
|
}
|