feat: appendConversationHistory for ai planner

This commit is contained in:
lilong.129 2025-03-22 00:06:30 +08:00
parent bbc05513f9
commit 8a3b6b5c4c
6 changed files with 196 additions and 104 deletions

View File

@ -1 +1 @@
v5.0.0-beta-2503202053 v5.0.0-beta-2503220006

View File

@ -21,9 +21,9 @@ func (s openaiLLMService) Call(opts *PlanningOptions) (*PlanningResult, error) {
// PlanningOptions represents the input options for planning // PlanningOptions represents the input options for planning
type PlanningOptions struct { type PlanningOptions struct {
UserInstruction string `json:"user_instruction"` UserInstruction string `json:"user_instruction"` // append to system prompt
ConversationHistory []*schema.Message `json:"conversation_history"` Message *schema.Message `json:"message"`
Size types.Size `json:"size"` Size types.Size `json:"size"`
} }
// PlanningResult represents the result of planning // PlanningResult represents the result of planning

View File

@ -47,25 +47,41 @@ func NewPlanner(ctx context.Context) (*Planner, error) {
} }
type Planner struct { type Planner struct {
ctx context.Context ctx context.Context
model model.ChatModel model model.ChatModel
parser *ActionParser parser *ActionParser
history []*schema.Message // conversation history
} }
// Call performs UI planning using Vision Language Model // Call performs UI planning using Vision Language Model
func (p *Planner) Call(opts *PlanningOptions) (*PlanningResult, error) { func (p *Planner) Call(opts *PlanningOptions) (*PlanningResult, error) {
log.Info().Str("user_instruction", opts.UserInstruction).Msg("start VLM planning")
// validate input parameters // validate input parameters
if err := validateInput(opts); err != nil { if err := validateInput(opts); err != nil {
return nil, errors.Wrap(err, "validate input parameters failed") return nil, errors.Wrap(err, "validate input parameters failed")
} }
// call VLM service // prepare prompt
resp, err := p.callVLMService(opts) if len(p.history) == 0 {
if err != nil { // add system message
return nil, errors.Wrap(err, "call VLM service failed") systemPrompt := uiTarsPlanningPrompt + opts.UserInstruction
p.history = []*schema.Message{
{
Role: schema.System,
Content: systemPrompt,
},
}
} }
// append user image message
p.appendConversationHistory(opts.Message)
// call model service, generate response
logRequest(p.history)
log.Info().Msg("calling model service...")
resp, err := p.model.Generate(p.ctx, p.history)
if err != nil {
return nil, fmt.Errorf("request model service failed: %w", err)
}
logResponse(resp)
// parse result // parse result
result, err := p.parseResult(resp, opts.Size) result, err := p.parseResult(resp, opts.Size)
@ -73,10 +89,12 @@ func (p *Planner) Call(opts *PlanningOptions) (*PlanningResult, error) {
return nil, errors.Wrap(err, "parse result failed") return nil, errors.Wrap(err, "parse result failed")
} }
log.Info(). // append assistant message
Interface("summary", result.ActionSummary). p.appendConversationHistory(&schema.Message{
Interface("actions", result.NextActions). Role: schema.Assistant,
Msg("get VLM planning result") Content: result.ActionSummary,
})
return result, nil return result, nil
} }
@ -85,57 +103,107 @@ func validateInput(opts *PlanningOptions) error {
return ErrEmptyInstruction return ErrEmptyInstruction
} }
if len(opts.ConversationHistory) == 0 { if opts.Message == nil {
return ErrNoConversationHistory return ErrNoConversationHistory
} }
// ensure at least one image URL if opts.Message.Role == schema.User {
hasImageURL := false // check MultiContent
for _, msg := range opts.ConversationHistory { if len(opts.Message.MultiContent) > 0 {
if msg.Role == "user" { for _, content := range opts.Message.MultiContent {
// check MultiContent if content.Type == schema.ChatMessagePartTypeImageURL && content.ImageURL == nil {
if len(msg.MultiContent) > 0 { return ErrInvalidImageData
for _, content := range msg.MultiContent {
if content.Type == "image_url" && content.ImageURL != nil {
hasImageURL = true
break
}
} }
} }
} }
if hasImageURL {
break
}
}
if !hasImageURL {
return ErrInvalidImageData
} }
return nil return nil
} }
// callVLMService makes the actual call to the VLM service func logRequest(messages []*schema.Message) {
func (p *Planner) callVLMService(opts *PlanningOptions) (*schema.Message, error) { msgs := make([]*schema.Message, 0, len(messages))
log.Info().Msg("calling VLM service...") for _, message := range messages {
msg := &schema.Message{
// prepare prompt Role: message.Role,
systemPrompt := uiTarsPlanningPrompt + opts.UserInstruction }
messages := []*schema.Message{ if message.Content != "" {
{ msg.Content = message.Content
Role: schema.System, } else if len(message.MultiContent) > 0 {
Content: systemPrompt, for _, mc := range message.MultiContent {
}, switch mc.Type {
case schema.ChatMessagePartTypeImageURL:
// Create a copy of the ImageURL to avoid modifying the original message
imageURLCopy := *mc.ImageURL
if strings.HasPrefix(imageURLCopy.URL, "data:image/") {
imageURLCopy.URL = "<data:image/base64...>"
}
msg.MultiContent = append(msg.MultiContent, schema.ChatMessagePart{
Type: mc.Type,
ImageURL: &imageURLCopy,
})
}
}
}
msgs = append(msgs, msg)
} }
messages = append(messages, opts.ConversationHistory...) log.Debug().Interface("messages", msgs).Msg("log request messages")
}
// generate response func logResponse(resp *schema.Message) {
resp, err := p.model.Generate(p.ctx, messages) log.Info().Str("role", string(resp.Role)).
if err != nil { Str("content", resp.Content).Msg("log response message")
return nil, fmt.Errorf("OpenAI API request failed: %w", err) }
// appendConversationHistory appends msg to p.history and trims older
// entries so that the history stays bounded.
//
// Only messages whose Role is schema.User or schema.Assistant are stored;
// a message with any other role leaves the history unchanged, and entries
// with other roles that are already in the history are never removed here.
// msg must not be nil: its Role is read without a nil check.
func (p *Planner) appendConversationHistory(msg *schema.Message) {
// User messages:
// - at most 4 user messages are kept in the history
// - when 4 are already present, the oldest one is removed before msg is appended
// Every message with Role == schema.User counts toward the limit; the
// content is not inspected to confirm that it actually carries an image.
if msg.Role == schema.User {
// scan the history for existing user messages
userImgCount := 0
firstUserImgIndex := -1
// count the user messages and remember the index of the oldest one
for i, item := range p.history {
if item.Role == schema.User {
userImgCount++
if firstUserImgIndex == -1 {
firstUserImgIndex = i
}
}
}
// the limit is already reached: drop the oldest user message to make room
if userImgCount >= 4 && firstUserImgIndex >= 0 {
// remove the element at firstUserImgIndex, keeping the order of the rest
p.history = append(
p.history[:firstUserImgIndex],
p.history[firstUserImgIndex+1:]...,
)
}
// add the new user message to the end of the history
p.history = append(p.history, msg)
}
// Assistant messages:
// - only the 10 most recent assistant messages are kept
if msg.Role == schema.Assistant {
// append first, so that msg itself counts as the newest of the 10
p.history = append(p.history, msg)
// Walk from the newest entry to the oldest and remove every assistant
// message found after 10 have been counted. Deleting at index i only
// shifts the elements after i, so the indices still to be visited by
// this descending loop remain valid.
assistantMsgCount := 0
for i := len(p.history) - 1; i >= 0; i-- {
if p.history[i].Role == schema.Assistant {
assistantMsgCount++
if assistantMsgCount > 10 {
p.history = append(p.history[:i], p.history[i+1:]...)
}
}
}
} }
log.Info().Str("content", resp.Content).Msg("get VLM response")
return resp, nil
} }
func (p *Planner) parseResult(msg *schema.Message, size types.Size) (*PlanningResult, error) { func (p *Planner) parseResult(msg *schema.Message, size types.Size) (*PlanningResult, error) {
@ -151,6 +219,10 @@ func (p *Planner) parseResult(msg *schema.Message, size types.Size) (*PlanningRe
return nil, errors.Wrap(err, "process VLM response failed") return nil, errors.Wrap(err, "process VLM response failed")
} }
log.Info().
Interface("summary", result.ActionSummary).
Interface("actions", result.NextActions).
Msg("get VLM planning result")
return result, nil return result, nil
} }

View File

@ -12,7 +12,7 @@ import (
) )
func TestVLMPlanning(t *testing.T) { func TestVLMPlanning(t *testing.T) {
imageBase64, size, err := loadImage("testdata/llk_3.png") imageBase64, size, err := loadImage("testdata/llk_1.png")
require.NoError(t, err) require.NoError(t, err)
userInstruction := `连连看是一款经典的益智消除类小游戏通常以图案或图标为主要元素以下是连连看的基本规则说明 userInstruction := `连连看是一款经典的益智消除类小游戏通常以图案或图标为主要元素以下是连连看的基本规则说明
@ -33,15 +33,13 @@ func TestVLMPlanning(t *testing.T) {
opts := &PlanningOptions{ opts := &PlanningOptions{
UserInstruction: userInstruction, UserInstruction: userInstruction,
ConversationHistory: []*schema.Message{ Message: &schema.Message{
{ Role: schema.User,
Role: schema.User, MultiContent: []schema.ChatMessagePart{
MultiContent: []schema.ChatMessagePart{ {
{ Type: schema.ChatMessagePartTypeImageURL,
Type: "image_url", ImageURL: &schema.ChatMessageImageURL{
ImageURL: &schema.ChatMessageImageURL{ URL: imageBase64,
URL: imageBase64,
},
}, },
}, },
}, },
@ -108,15 +106,13 @@ func TestXHSPlanning(t *testing.T) {
opts := &PlanningOptions{ opts := &PlanningOptions{
UserInstruction: userInstruction, UserInstruction: userInstruction,
ConversationHistory: []*schema.Message{ Message: &schema.Message{
{ Role: schema.User,
Role: schema.User, MultiContent: []schema.ChatMessagePart{
MultiContent: []schema.ChatMessagePart{ {
{ Type: schema.ChatMessagePartTypeImageURL,
Type: "image_url", ImageURL: &schema.ChatMessageImageURL{
ImageURL: &schema.ChatMessageImageURL{ URL: imageBase64,
URL: imageBase64,
},
}, },
}, },
}, },
@ -182,15 +178,13 @@ func TestValidateInput(t *testing.T) {
name: "valid input", name: "valid input",
opts: &PlanningOptions{ opts: &PlanningOptions{
UserInstruction: "点击继续使用按钮", UserInstruction: "点击继续使用按钮",
ConversationHistory: []*schema.Message{ Message: &schema.Message{
{ Role: schema.User,
Role: schema.User, MultiContent: []schema.ChatMessagePart{
MultiContent: []schema.ChatMessagePart{ {
{ Type: "image_url",
Type: "image_url", ImageURL: &schema.ChatMessageImageURL{
ImageURL: &schema.ChatMessageImageURL{ URL: imageBase64,
URL: imageBase64,
},
}, },
}, },
}, },
@ -203,11 +197,9 @@ func TestValidateInput(t *testing.T) {
name: "empty instruction", name: "empty instruction",
opts: &PlanningOptions{ opts: &PlanningOptions{
UserInstruction: "", UserInstruction: "",
ConversationHistory: []*schema.Message{ Message: &schema.Message{
{ Role: schema.User,
Role: schema.User, Content: "",
Content: "",
},
}, },
Size: size, Size: size,
}, },
@ -216,8 +208,8 @@ func TestValidateInput(t *testing.T) {
{ {
name: "empty conversation history", name: "empty conversation history",
opts: &PlanningOptions{ opts: &PlanningOptions{
UserInstruction: "点击立即卸载按钮", UserInstruction: "点击立即卸载按钮",
ConversationHistory: []*schema.Message{}, Message: &schema.Message{},
}, },
wantErr: ErrNoConversationHistory, wantErr: ErrNoConversationHistory,
}, },
@ -225,11 +217,9 @@ func TestValidateInput(t *testing.T) {
name: "invalid image data", name: "invalid image data",
opts: &PlanningOptions{ opts: &PlanningOptions{
UserInstruction: "点击继续使用按钮", UserInstruction: "点击继续使用按钮",
ConversationHistory: []*schema.Message{ Message: &schema.Message{
{ Role: schema.User,
Role: schema.User, Content: "no image",
Content: "no image",
},
}, },
Size: size, Size: size,
}, },

View File

@ -6,8 +6,41 @@ import (
"github.com/httprunner/httprunner/v5/uixt/ai" "github.com/httprunner/httprunner/v5/uixt/ai"
"github.com/httprunner/httprunner/v5/uixt/option" "github.com/httprunner/httprunner/v5/uixt/option"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/rs/zerolog/log"
) )
// StartToGoal repeatedly asks the planner for the next actions toward the
// goal described by text and executes them until the planner reports that
// the task is finished.
//
// Each iteration calls PlanNextAction and then performs the returned actions
// in order: a click action taps the coordinates taken from its "startBox"
// input, a finished action ends the loop successfully, and any other action
// type is ignored.
//
// It returns nil once a finished action is seen. It returns an error when
// planning fails, when a tap fails, when a click action carries a missing or
// malformed "startBox" input, or when the attempt limit is reached. The limit
// is only enforced when MaxRetryTimes is greater than 1; otherwise the loop
// keeps running until the planner finishes or an error occurs.
func (dExt *XTDriver) StartToGoal(text string, opts ...option.ActionOption) error {
	options := option.NewActionOptions(opts...)

	var attempt int
	for {
		attempt++
		log.Info().Int("attempt", attempt).Msg("planning attempt")

		// plan next action
		result, err := dExt.PlanNextAction(text, opts...)
		if err != nil {
			return errors.Wrap(err, "failed to get next action from planner")
		}

		// do actions
		for _, action := range result.NextActions {
			switch action.ActionType {
			case ai.ActionTypeClick:
				// The inputs are produced by the planner, so verify the
				// type and length instead of panicking on unexpected data.
				point, ok := action.ActionInputs["startBox"].([]float64)
				if !ok || len(point) < 2 {
					return errors.Errorf("invalid startBox for click action: %v",
						action.ActionInputs["startBox"])
				}
				if err := dExt.TapAbsXY(point[0], point[1], opts...); err != nil {
					return err
				}
			case ai.ActionTypeFinished:
				return nil
			}
		}

		if options.MaxRetryTimes > 1 && attempt >= options.MaxRetryTimes {
			return errors.New("reached max retry times")
		}
	}
}
func (dExt *XTDriver) PlanNextAction(text string, opts ...option.ActionOption) (*ai.PlanningResult, error) { func (dExt *XTDriver) PlanNextAction(text string, opts ...option.ActionOption) (*ai.PlanningResult, error) {
if dExt.LLMService == nil { if dExt.LLMService == nil {
return nil, errors.New("LLM service is not initialized") return nil, errors.New("LLM service is not initialized")
@ -25,15 +58,13 @@ func (dExt *XTDriver) PlanNextAction(text string, opts ...option.ActionOption) (
planningOpts := &ai.PlanningOptions{ planningOpts := &ai.PlanningOptions{
UserInstruction: text, UserInstruction: text,
ConversationHistory: []*schema.Message{ Message: &schema.Message{
{ Role: schema.User,
Role: schema.User, MultiContent: []schema.ChatMessagePart{
MultiContent: []schema.ChatMessagePart{ {
{ Type: schema.ChatMessagePartTypeImageURL,
Type: "image_url", ImageURL: &schema.ChatMessageImageURL{
ImageURL: &schema.ChatMessageImageURL{ URL: screenShotBase64,
URL: screenShotBase64,
},
}, },
}, },
}, },

View File

@ -8,7 +8,6 @@ import (
"image/gif" "image/gif"
"image/jpeg" "image/jpeg"
"image/png" "image/png"
_ "image/png"
"os" "os"
"path/filepath" "path/filepath"
"strings" "strings"