You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
64 lines
2.0 KiB
64 lines
2.0 KiB
2 months ago
|
package ai_model_cli
|
||
|
|
||
|
import (
|
||
|
"context"
|
||
|
"fmt"
|
||
|
"goskeleton/app/global/variable"
|
||
|
"os"
|
||
|
|
||
|
"goskeleton/app/global/consts"
|
||
|
|
||
|
"github.com/baidubce/bce-qianfan-sdk/go/qianfan"
|
||
|
"github.com/gin-gonic/gin"
|
||
|
)
|
||
|
|
||
|
func RequestStyle(c *gin.Context) (interface{}, error) {
|
||
|
|
||
|
// userMsg := c.PostForm("user_input")
|
||
|
userMsg:=c.GetString(consts.ValidatorPrefix+"user_input")
|
||
|
|
||
|
qianfan.GetConfig().AccessKey = variable.ConfigYml.GetString("BaiduCE.QianFanAccessKey")
|
||
|
qianfan.GetConfig().SecretKey = variable.ConfigYml.GetString("BaiduCE.QianFanSecretKey")
|
||
|
|
||
|
chat := qianfan.NewChatCompletion(
|
||
|
qianfan.WithModel("ERNIE-4.0-8K"),
|
||
|
)
|
||
|
|
||
|
chatHistory := []qianfan.ChatCompletionMessage{}
|
||
|
|
||
|
// 读取prompt文件
|
||
|
systemMsgPath := variable.ConfigYml.GetString("BaiduCE.StyleGeneratePromptPath")
|
||
|
// 读取文件内容
|
||
|
prompt, err := os.ReadFile(variable.BasePath+systemMsgPath)
|
||
|
if err != nil || len(prompt) == 0 {
|
||
|
variable.ZapLog.Error(fmt.Sprintf("读取提示词文件失败: %v", err))
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
// add user history to chat history
|
||
|
userHistory,exist := c.Get(consts.ValidatorPrefix+"chat_history")
|
||
|
if exist&&userHistory!=nil{
|
||
|
// TODO: check if userHistory is of type []struct{Role string;Content string}
|
||
|
userHistory := userHistory.([]struct{Role string;Content string})
|
||
|
if len(userHistory)%2!=0{
|
||
|
variable.ZapLog.Error(fmt.Sprintf("用户历史对话格式错误: %v", userHistory))
|
||
|
return nil, fmt.Errorf("用户历史对话格式错误")
|
||
|
}
|
||
|
for _,msg := range userHistory{
|
||
|
chatHistory = append(chatHistory, qianfan.ChatCompletionMessage{Role:msg.Role,Content:msg.Content})
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// add user input to chat history
|
||
|
chatHistory = append(chatHistory, qianfan.ChatCompletionUserMessage(userMsg))
|
||
|
|
||
|
// define a stream chat client
|
||
|
response,err:=chat.Do(context.TODO(),&qianfan.ChatCompletionRequest{System: string(prompt),Messages: chatHistory})
|
||
|
if err != nil {
|
||
|
variable.ZapLog.Error(fmt.Sprintf("对话失败: %v", err))
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
return response.Result, nil
|
||
|
}
|