feat: appendConversationHistory for ai planner
This commit is contained in:
parent
bbc05513f9
commit
8a3b6b5c4c
|
@ -1 +1 @@
|
|||
v5.0.0-beta-2503202053
|
||||
v5.0.0-beta-2503220006
|
||||
|
|
|
@ -21,9 +21,9 @@ func (s openaiLLMService) Call(opts *PlanningOptions) (*PlanningResult, error) {
|
|||
|
||||
// PlanningOptions represents the input options for planning
|
||||
type PlanningOptions struct {
|
||||
UserInstruction string `json:"user_instruction"`
|
||||
ConversationHistory []*schema.Message `json:"conversation_history"`
|
||||
Size types.Size `json:"size"`
|
||||
UserInstruction string `json:"user_instruction"` // append to system prompt
|
||||
Message *schema.Message `json:"message"`
|
||||
Size types.Size `json:"size"`
|
||||
}
|
||||
|
||||
// PlanningResult represents the result of planning
|
||||
|
|
|
@ -47,25 +47,41 @@ func NewPlanner(ctx context.Context) (*Planner, error) {
|
|||
}
|
||||
|
||||
type Planner struct {
|
||||
ctx context.Context
|
||||
model model.ChatModel
|
||||
parser *ActionParser
|
||||
ctx context.Context
|
||||
model model.ChatModel
|
||||
parser *ActionParser
|
||||
history []*schema.Message // conversation history
|
||||
}
|
||||
|
||||
// Call performs UI planning using Vision Language Model
|
||||
func (p *Planner) Call(opts *PlanningOptions) (*PlanningResult, error) {
|
||||
log.Info().Str("user_instruction", opts.UserInstruction).Msg("start VLM planning")
|
||||
|
||||
// validate input parameters
|
||||
if err := validateInput(opts); err != nil {
|
||||
return nil, errors.Wrap(err, "validate input parameters failed")
|
||||
}
|
||||
|
||||
// call VLM service
|
||||
resp, err := p.callVLMService(opts)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "call VLM service failed")
|
||||
// prepare prompt
|
||||
if len(p.history) == 0 {
|
||||
// add system message
|
||||
systemPrompt := uiTarsPlanningPrompt + opts.UserInstruction
|
||||
p.history = []*schema.Message{
|
||||
{
|
||||
Role: schema.System,
|
||||
Content: systemPrompt,
|
||||
},
|
||||
}
|
||||
}
|
||||
// append user image message
|
||||
p.appendConversationHistory(opts.Message)
|
||||
|
||||
// call model service, generate response
|
||||
logRequest(p.history)
|
||||
log.Info().Msg("calling model service...")
|
||||
resp, err := p.model.Generate(p.ctx, p.history)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request model service failed: %w", err)
|
||||
}
|
||||
logResponse(resp)
|
||||
|
||||
// parse result
|
||||
result, err := p.parseResult(resp, opts.Size)
|
||||
|
@ -73,10 +89,12 @@ func (p *Planner) Call(opts *PlanningOptions) (*PlanningResult, error) {
|
|||
return nil, errors.Wrap(err, "parse result failed")
|
||||
}
|
||||
|
||||
log.Info().
|
||||
Interface("summary", result.ActionSummary).
|
||||
Interface("actions", result.NextActions).
|
||||
Msg("get VLM planning result")
|
||||
// append assistant message
|
||||
p.appendConversationHistory(&schema.Message{
|
||||
Role: schema.Assistant,
|
||||
Content: result.ActionSummary,
|
||||
})
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
|
@ -85,57 +103,107 @@ func validateInput(opts *PlanningOptions) error {
|
|||
return ErrEmptyInstruction
|
||||
}
|
||||
|
||||
if len(opts.ConversationHistory) == 0 {
|
||||
if opts.Message == nil {
|
||||
return ErrNoConversationHistory
|
||||
}
|
||||
|
||||
// ensure at least one image URL
|
||||
hasImageURL := false
|
||||
for _, msg := range opts.ConversationHistory {
|
||||
if msg.Role == "user" {
|
||||
// check MultiContent
|
||||
if len(msg.MultiContent) > 0 {
|
||||
for _, content := range msg.MultiContent {
|
||||
if content.Type == "image_url" && content.ImageURL != nil {
|
||||
hasImageURL = true
|
||||
break
|
||||
}
|
||||
if opts.Message.Role == schema.User {
|
||||
// check MultiContent
|
||||
if len(opts.Message.MultiContent) > 0 {
|
||||
for _, content := range opts.Message.MultiContent {
|
||||
if content.Type == schema.ChatMessagePartTypeImageURL && content.ImageURL == nil {
|
||||
return ErrInvalidImageData
|
||||
}
|
||||
}
|
||||
}
|
||||
if hasImageURL {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !hasImageURL {
|
||||
return ErrInvalidImageData
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// callVLMService makes the actual call to the VLM service
|
||||
func (p *Planner) callVLMService(opts *PlanningOptions) (*schema.Message, error) {
|
||||
log.Info().Msg("calling VLM service...")
|
||||
|
||||
// prepare prompt
|
||||
systemPrompt := uiTarsPlanningPrompt + opts.UserInstruction
|
||||
messages := []*schema.Message{
|
||||
{
|
||||
Role: schema.System,
|
||||
Content: systemPrompt,
|
||||
},
|
||||
func logRequest(messages []*schema.Message) {
|
||||
msgs := make([]*schema.Message, 0, len(messages))
|
||||
for _, message := range messages {
|
||||
msg := &schema.Message{
|
||||
Role: message.Role,
|
||||
}
|
||||
if message.Content != "" {
|
||||
msg.Content = message.Content
|
||||
} else if len(message.MultiContent) > 0 {
|
||||
for _, mc := range message.MultiContent {
|
||||
switch mc.Type {
|
||||
case schema.ChatMessagePartTypeImageURL:
|
||||
// Create a copy of the ImageURL to avoid modifying the original message
|
||||
imageURLCopy := *mc.ImageURL
|
||||
if strings.HasPrefix(imageURLCopy.URL, "data:image/") {
|
||||
imageURLCopy.URL = "<data:image/base64...>"
|
||||
}
|
||||
msg.MultiContent = append(msg.MultiContent, schema.ChatMessagePart{
|
||||
Type: mc.Type,
|
||||
ImageURL: &imageURLCopy,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
msgs = append(msgs, msg)
|
||||
}
|
||||
messages = append(messages, opts.ConversationHistory...)
|
||||
log.Debug().Interface("messages", msgs).Msg("log request messages")
|
||||
}
|
||||
|
||||
// generate response
|
||||
resp, err := p.model.Generate(p.ctx, messages)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("OpenAI API request failed: %w", err)
|
||||
func logResponse(resp *schema.Message) {
|
||||
log.Info().Str("role", string(resp.Role)).
|
||||
Str("content", resp.Content).Msg("log response message")
|
||||
}
|
||||
|
||||
// appendConversationHistory adds a message to the conversation history
|
||||
func (p *Planner) appendConversationHistory(msg *schema.Message) {
|
||||
// for user image message:
|
||||
// - keep at most 4 user image messages
|
||||
// - delete the oldest user image message when the limit is reached
|
||||
if msg.Role == schema.User {
|
||||
// get all existing user messages
|
||||
userImgCount := 0
|
||||
firstUserImgIndex := -1
|
||||
|
||||
// calculate the number of user messages and find the index of the first user message
|
||||
for i, item := range p.history {
|
||||
if item.Role == schema.User {
|
||||
userImgCount++
|
||||
if firstUserImgIndex == -1 {
|
||||
firstUserImgIndex = i
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// if there are already 4 user messages, delete the first one before adding the new message
|
||||
if userImgCount >= 4 && firstUserImgIndex >= 0 {
|
||||
// delete the first user message
|
||||
p.history = append(
|
||||
p.history[:firstUserImgIndex],
|
||||
p.history[firstUserImgIndex+1:]...,
|
||||
)
|
||||
}
|
||||
// add the new user message to the history
|
||||
p.history = append(p.history, msg)
|
||||
}
|
||||
|
||||
// for assistant message:
|
||||
// - keep at most the last 10 assistant messages
|
||||
if msg.Role == schema.Assistant {
|
||||
// add the new assistant message to the history
|
||||
p.history = append(p.history, msg)
|
||||
|
||||
// if there are more than 10 assistant messages, remove the oldest ones
|
||||
assistantMsgCount := 0
|
||||
for i := len(p.history) - 1; i >= 0; i-- {
|
||||
if p.history[i].Role == schema.Assistant {
|
||||
assistantMsgCount++
|
||||
if assistantMsgCount > 10 {
|
||||
p.history = append(p.history[:i], p.history[i+1:]...)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
log.Info().Str("content", resp.Content).Msg("get VLM response")
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (p *Planner) parseResult(msg *schema.Message, size types.Size) (*PlanningResult, error) {
|
||||
|
@ -151,6 +219,10 @@ func (p *Planner) parseResult(msg *schema.Message, size types.Size) (*PlanningRe
|
|||
return nil, errors.Wrap(err, "process VLM response failed")
|
||||
}
|
||||
|
||||
log.Info().
|
||||
Interface("summary", result.ActionSummary).
|
||||
Interface("actions", result.NextActions).
|
||||
Msg("get VLM planning result")
|
||||
return result, nil
|
||||
}
|
||||
|
||||
|
|
|
@ -12,7 +12,7 @@ import (
|
|||
)
|
||||
|
||||
func TestVLMPlanning(t *testing.T) {
|
||||
imageBase64, size, err := loadImage("testdata/llk_3.png")
|
||||
imageBase64, size, err := loadImage("testdata/llk_1.png")
|
||||
require.NoError(t, err)
|
||||
|
||||
userInstruction := `连连看是一款经典的益智消除类小游戏,通常以图案或图标为主要元素。以下是连连看的基本规则说明:
|
||||
|
@ -33,15 +33,13 @@ func TestVLMPlanning(t *testing.T) {
|
|||
|
||||
opts := &PlanningOptions{
|
||||
UserInstruction: userInstruction,
|
||||
ConversationHistory: []*schema.Message{
|
||||
{
|
||||
Role: schema.User,
|
||||
MultiContent: []schema.ChatMessagePart{
|
||||
{
|
||||
Type: "image_url",
|
||||
ImageURL: &schema.ChatMessageImageURL{
|
||||
URL: imageBase64,
|
||||
},
|
||||
Message: &schema.Message{
|
||||
Role: schema.User,
|
||||
MultiContent: []schema.ChatMessagePart{
|
||||
{
|
||||
Type: schema.ChatMessagePartTypeImageURL,
|
||||
ImageURL: &schema.ChatMessageImageURL{
|
||||
URL: imageBase64,
|
||||
},
|
||||
},
|
||||
},
|
||||
|
@ -108,15 +106,13 @@ func TestXHSPlanning(t *testing.T) {
|
|||
|
||||
opts := &PlanningOptions{
|
||||
UserInstruction: userInstruction,
|
||||
ConversationHistory: []*schema.Message{
|
||||
{
|
||||
Role: schema.User,
|
||||
MultiContent: []schema.ChatMessagePart{
|
||||
{
|
||||
Type: "image_url",
|
||||
ImageURL: &schema.ChatMessageImageURL{
|
||||
URL: imageBase64,
|
||||
},
|
||||
Message: &schema.Message{
|
||||
Role: schema.User,
|
||||
MultiContent: []schema.ChatMessagePart{
|
||||
{
|
||||
Type: schema.ChatMessagePartTypeImageURL,
|
||||
ImageURL: &schema.ChatMessageImageURL{
|
||||
URL: imageBase64,
|
||||
},
|
||||
},
|
||||
},
|
||||
|
@ -182,15 +178,13 @@ func TestValidateInput(t *testing.T) {
|
|||
name: "valid input",
|
||||
opts: &PlanningOptions{
|
||||
UserInstruction: "点击继续使用按钮",
|
||||
ConversationHistory: []*schema.Message{
|
||||
{
|
||||
Role: schema.User,
|
||||
MultiContent: []schema.ChatMessagePart{
|
||||
{
|
||||
Type: "image_url",
|
||||
ImageURL: &schema.ChatMessageImageURL{
|
||||
URL: imageBase64,
|
||||
},
|
||||
Message: &schema.Message{
|
||||
Role: schema.User,
|
||||
MultiContent: []schema.ChatMessagePart{
|
||||
{
|
||||
Type: "image_url",
|
||||
ImageURL: &schema.ChatMessageImageURL{
|
||||
URL: imageBase64,
|
||||
},
|
||||
},
|
||||
},
|
||||
|
@ -203,11 +197,9 @@ func TestValidateInput(t *testing.T) {
|
|||
name: "empty instruction",
|
||||
opts: &PlanningOptions{
|
||||
UserInstruction: "",
|
||||
ConversationHistory: []*schema.Message{
|
||||
{
|
||||
Role: schema.User,
|
||||
Content: "",
|
||||
},
|
||||
Message: &schema.Message{
|
||||
Role: schema.User,
|
||||
Content: "",
|
||||
},
|
||||
Size: size,
|
||||
},
|
||||
|
@ -216,8 +208,8 @@ func TestValidateInput(t *testing.T) {
|
|||
{
|
||||
name: "empty conversation history",
|
||||
opts: &PlanningOptions{
|
||||
UserInstruction: "点击立即卸载按钮",
|
||||
ConversationHistory: []*schema.Message{},
|
||||
UserInstruction: "点击立即卸载按钮",
|
||||
Message: &schema.Message{},
|
||||
},
|
||||
wantErr: ErrNoConversationHistory,
|
||||
},
|
||||
|
@ -225,11 +217,9 @@ func TestValidateInput(t *testing.T) {
|
|||
name: "invalid image data",
|
||||
opts: &PlanningOptions{
|
||||
UserInstruction: "点击继续使用按钮",
|
||||
ConversationHistory: []*schema.Message{
|
||||
{
|
||||
Role: schema.User,
|
||||
Content: "no image",
|
||||
},
|
||||
Message: &schema.Message{
|
||||
Role: schema.User,
|
||||
Content: "no image",
|
||||
},
|
||||
Size: size,
|
||||
},
|
||||
|
|
|
@ -6,8 +6,41 @@ import (
|
|||
"github.com/httprunner/httprunner/v5/uixt/ai"
|
||||
"github.com/httprunner/httprunner/v5/uixt/option"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
func (dExt *XTDriver) StartToGoal(text string, opts ...option.ActionOption) error {
|
||||
options := option.NewActionOptions(opts...)
|
||||
var attempt int
|
||||
for {
|
||||
attempt++
|
||||
log.Info().Int("attempt", attempt).Msg("planning attempt")
|
||||
|
||||
// plan next action
|
||||
result, err := dExt.PlanNextAction(text, opts...)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to get next action from planner")
|
||||
}
|
||||
|
||||
// do actions
|
||||
for _, action := range result.NextActions {
|
||||
switch action.ActionType {
|
||||
case ai.ActionTypeClick:
|
||||
point := action.ActionInputs["startBox"].([]float64)
|
||||
if err := dExt.TapAbsXY(point[0], point[1], opts...); err != nil {
|
||||
return err
|
||||
}
|
||||
case ai.ActionTypeFinished:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
if options.MaxRetryTimes > 1 && attempt >= options.MaxRetryTimes {
|
||||
return errors.New("reached max retry times")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (dExt *XTDriver) PlanNextAction(text string, opts ...option.ActionOption) (*ai.PlanningResult, error) {
|
||||
if dExt.LLMService == nil {
|
||||
return nil, errors.New("LLM service is not initialized")
|
||||
|
@ -25,15 +58,13 @@ func (dExt *XTDriver) PlanNextAction(text string, opts ...option.ActionOption) (
|
|||
|
||||
planningOpts := &ai.PlanningOptions{
|
||||
UserInstruction: text,
|
||||
ConversationHistory: []*schema.Message{
|
||||
{
|
||||
Role: schema.User,
|
||||
MultiContent: []schema.ChatMessagePart{
|
||||
{
|
||||
Type: "image_url",
|
||||
ImageURL: &schema.ChatMessageImageURL{
|
||||
URL: screenShotBase64,
|
||||
},
|
||||
Message: &schema.Message{
|
||||
Role: schema.User,
|
||||
MultiContent: []schema.ChatMessagePart{
|
||||
{
|
||||
Type: schema.ChatMessagePartTypeImageURL,
|
||||
ImageURL: &schema.ChatMessageImageURL{
|
||||
URL: screenShotBase64,
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
|
@ -8,7 +8,6 @@ import (
|
|||
"image/gif"
|
||||
"image/jpeg"
|
||||
"image/png"
|
||||
_ "image/png"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
|
Loading…
Reference in New Issue