feat: appendConversationHistory for ai planner

This commit is contained in:
lilong.129 2025-03-22 00:06:30 +08:00
parent bbc05513f9
commit 8a3b6b5c4c
6 changed files with 196 additions and 104 deletions

View File

@ -1 +1 @@
v5.0.0-beta-2503202053
v5.0.0-beta-2503220006

View File

@ -21,9 +21,9 @@ func (s openaiLLMService) Call(opts *PlanningOptions) (*PlanningResult, error) {
// PlanningOptions represents the input options for planning
type PlanningOptions struct {
UserInstruction string `json:"user_instruction"`
ConversationHistory []*schema.Message `json:"conversation_history"`
Size types.Size `json:"size"`
UserInstruction string `json:"user_instruction"` // append to system prompt
Message *schema.Message `json:"message"`
Size types.Size `json:"size"`
}
// PlanningResult represents the result of planning

View File

@ -47,25 +47,41 @@ func NewPlanner(ctx context.Context) (*Planner, error) {
}
type Planner struct {
ctx context.Context
model model.ChatModel
parser *ActionParser
ctx context.Context
model model.ChatModel
parser *ActionParser
history []*schema.Message // conversation history
}
// Call performs UI planning using Vision Language Model
func (p *Planner) Call(opts *PlanningOptions) (*PlanningResult, error) {
log.Info().Str("user_instruction", opts.UserInstruction).Msg("start VLM planning")
// validate input parameters
if err := validateInput(opts); err != nil {
return nil, errors.Wrap(err, "validate input parameters failed")
}
// call VLM service
resp, err := p.callVLMService(opts)
if err != nil {
return nil, errors.Wrap(err, "call VLM service failed")
// prepare prompt
if len(p.history) == 0 {
// add system message
systemPrompt := uiTarsPlanningPrompt + opts.UserInstruction
p.history = []*schema.Message{
{
Role: schema.System,
Content: systemPrompt,
},
}
}
// append user image message
p.appendConversationHistory(opts.Message)
// call model service, generate response
logRequest(p.history)
log.Info().Msg("calling model service...")
resp, err := p.model.Generate(p.ctx, p.history)
if err != nil {
return nil, fmt.Errorf("request model service failed: %w", err)
}
logResponse(resp)
// parse result
result, err := p.parseResult(resp, opts.Size)
@ -73,10 +89,12 @@ func (p *Planner) Call(opts *PlanningOptions) (*PlanningResult, error) {
return nil, errors.Wrap(err, "parse result failed")
}
log.Info().
Interface("summary", result.ActionSummary).
Interface("actions", result.NextActions).
Msg("get VLM planning result")
// append assistant message
p.appendConversationHistory(&schema.Message{
Role: schema.Assistant,
Content: result.ActionSummary,
})
return result, nil
}
@ -85,57 +103,107 @@ func validateInput(opts *PlanningOptions) error {
return ErrEmptyInstruction
}
if len(opts.ConversationHistory) == 0 {
if opts.Message == nil {
return ErrNoConversationHistory
}
// ensure at least one image URL
hasImageURL := false
for _, msg := range opts.ConversationHistory {
if msg.Role == "user" {
// check MultiContent
if len(msg.MultiContent) > 0 {
for _, content := range msg.MultiContent {
if content.Type == "image_url" && content.ImageURL != nil {
hasImageURL = true
break
}
if opts.Message.Role == schema.User {
// check MultiContent
if len(opts.Message.MultiContent) > 0 {
for _, content := range opts.Message.MultiContent {
if content.Type == schema.ChatMessagePartTypeImageURL && content.ImageURL == nil {
return ErrInvalidImageData
}
}
}
if hasImageURL {
break
}
}
if !hasImageURL {
return ErrInvalidImageData
}
return nil
}
// callVLMService makes the actual call to the VLM service
func (p *Planner) callVLMService(opts *PlanningOptions) (*schema.Message, error) {
log.Info().Msg("calling VLM service...")
// prepare prompt
systemPrompt := uiTarsPlanningPrompt + opts.UserInstruction
messages := []*schema.Message{
{
Role: schema.System,
Content: systemPrompt,
},
func logRequest(messages []*schema.Message) {
msgs := make([]*schema.Message, 0, len(messages))
for _, message := range messages {
msg := &schema.Message{
Role: message.Role,
}
if message.Content != "" {
msg.Content = message.Content
} else if len(message.MultiContent) > 0 {
for _, mc := range message.MultiContent {
switch mc.Type {
case schema.ChatMessagePartTypeImageURL:
// Create a copy of the ImageURL to avoid modifying the original message
imageURLCopy := *mc.ImageURL
if strings.HasPrefix(imageURLCopy.URL, "data:image/") {
imageURLCopy.URL = "<data:image/base64...>"
}
msg.MultiContent = append(msg.MultiContent, schema.ChatMessagePart{
Type: mc.Type,
ImageURL: &imageURLCopy,
})
}
}
}
msgs = append(msgs, msg)
}
messages = append(messages, opts.ConversationHistory...)
log.Debug().Interface("messages", msgs).Msg("log request messages")
}
// generate response
resp, err := p.model.Generate(p.ctx, messages)
if err != nil {
return nil, fmt.Errorf("OpenAI API request failed: %w", err)
func logResponse(resp *schema.Message) {
log.Info().Str("role", string(resp.Role)).
Str("content", resp.Content).Msg("log response message")
}
// appendConversationHistory adds a message to the conversation history
func (p *Planner) appendConversationHistory(msg *schema.Message) {
// for user image message:
// - keep at most 4 user image messages
// - delete the oldest user image message when the limit is reached
if msg.Role == schema.User {
// get all existing user messages
userImgCount := 0
firstUserImgIndex := -1
// calculate the number of user messages and find the index of the first user message
for i, item := range p.history {
if item.Role == schema.User {
userImgCount++
if firstUserImgIndex == -1 {
firstUserImgIndex = i
}
}
}
// if there are already 4 user messages, delete the first one before adding the new message
if userImgCount >= 4 && firstUserImgIndex >= 0 {
// delete the first user message
p.history = append(
p.history[:firstUserImgIndex],
p.history[firstUserImgIndex+1:]...,
)
}
// add the new user message to the history
p.history = append(p.history, msg)
}
// for assistant message:
// - keep at most the last 10 assistant messages
if msg.Role == schema.Assistant {
// add the new assistant message to the history
p.history = append(p.history, msg)
// if there are more than 10 assistant messages, remove the oldest ones
assistantMsgCount := 0
for i := len(p.history) - 1; i >= 0; i-- {
if p.history[i].Role == schema.Assistant {
assistantMsgCount++
if assistantMsgCount > 10 {
p.history = append(p.history[:i], p.history[i+1:]...)
}
}
}
}
log.Info().Str("content", resp.Content).Msg("get VLM response")
return resp, nil
}
func (p *Planner) parseResult(msg *schema.Message, size types.Size) (*PlanningResult, error) {
@ -151,6 +219,10 @@ func (p *Planner) parseResult(msg *schema.Message, size types.Size) (*PlanningRe
return nil, errors.Wrap(err, "process VLM response failed")
}
log.Info().
Interface("summary", result.ActionSummary).
Interface("actions", result.NextActions).
Msg("get VLM planning result")
return result, nil
}

View File

@ -12,7 +12,7 @@ import (
)
func TestVLMPlanning(t *testing.T) {
imageBase64, size, err := loadImage("testdata/llk_3.png")
imageBase64, size, err := loadImage("testdata/llk_1.png")
require.NoError(t, err)
userInstruction := `连连看是一款经典的益智消除类小游戏通常以图案或图标为主要元素以下是连连看的基本规则说明
@ -33,15 +33,13 @@ func TestVLMPlanning(t *testing.T) {
opts := &PlanningOptions{
UserInstruction: userInstruction,
ConversationHistory: []*schema.Message{
{
Role: schema.User,
MultiContent: []schema.ChatMessagePart{
{
Type: "image_url",
ImageURL: &schema.ChatMessageImageURL{
URL: imageBase64,
},
Message: &schema.Message{
Role: schema.User,
MultiContent: []schema.ChatMessagePart{
{
Type: schema.ChatMessagePartTypeImageURL,
ImageURL: &schema.ChatMessageImageURL{
URL: imageBase64,
},
},
},
@ -108,15 +106,13 @@ func TestXHSPlanning(t *testing.T) {
opts := &PlanningOptions{
UserInstruction: userInstruction,
ConversationHistory: []*schema.Message{
{
Role: schema.User,
MultiContent: []schema.ChatMessagePart{
{
Type: "image_url",
ImageURL: &schema.ChatMessageImageURL{
URL: imageBase64,
},
Message: &schema.Message{
Role: schema.User,
MultiContent: []schema.ChatMessagePart{
{
Type: schema.ChatMessagePartTypeImageURL,
ImageURL: &schema.ChatMessageImageURL{
URL: imageBase64,
},
},
},
@ -182,15 +178,13 @@ func TestValidateInput(t *testing.T) {
name: "valid input",
opts: &PlanningOptions{
UserInstruction: "点击继续使用按钮",
ConversationHistory: []*schema.Message{
{
Role: schema.User,
MultiContent: []schema.ChatMessagePart{
{
Type: "image_url",
ImageURL: &schema.ChatMessageImageURL{
URL: imageBase64,
},
Message: &schema.Message{
Role: schema.User,
MultiContent: []schema.ChatMessagePart{
{
Type: "image_url",
ImageURL: &schema.ChatMessageImageURL{
URL: imageBase64,
},
},
},
@ -203,11 +197,9 @@ func TestValidateInput(t *testing.T) {
name: "empty instruction",
opts: &PlanningOptions{
UserInstruction: "",
ConversationHistory: []*schema.Message{
{
Role: schema.User,
Content: "",
},
Message: &schema.Message{
Role: schema.User,
Content: "",
},
Size: size,
},
@ -216,8 +208,8 @@ func TestValidateInput(t *testing.T) {
{
name: "empty conversation history",
opts: &PlanningOptions{
UserInstruction: "点击立即卸载按钮",
ConversationHistory: []*schema.Message{},
UserInstruction: "点击立即卸载按钮",
Message: &schema.Message{},
},
wantErr: ErrNoConversationHistory,
},
@ -225,11 +217,9 @@ func TestValidateInput(t *testing.T) {
name: "invalid image data",
opts: &PlanningOptions{
UserInstruction: "点击继续使用按钮",
ConversationHistory: []*schema.Message{
{
Role: schema.User,
Content: "no image",
},
Message: &schema.Message{
Role: schema.User,
Content: "no image",
},
Size: size,
},

View File

@ -6,8 +6,41 @@ import (
"github.com/httprunner/httprunner/v5/uixt/ai"
"github.com/httprunner/httprunner/v5/uixt/option"
"github.com/pkg/errors"
"github.com/rs/zerolog/log"
)
func (dExt *XTDriver) StartToGoal(text string, opts ...option.ActionOption) error {
options := option.NewActionOptions(opts...)
var attempt int
for {
attempt++
log.Info().Int("attempt", attempt).Msg("planning attempt")
// plan next action
result, err := dExt.PlanNextAction(text, opts...)
if err != nil {
return errors.Wrap(err, "failed to get next action from planner")
}
// do actions
for _, action := range result.NextActions {
switch action.ActionType {
case ai.ActionTypeClick:
point := action.ActionInputs["startBox"].([]float64)
if err := dExt.TapAbsXY(point[0], point[1], opts...); err != nil {
return err
}
case ai.ActionTypeFinished:
return nil
}
}
if options.MaxRetryTimes > 1 && attempt >= options.MaxRetryTimes {
return errors.New("reached max retry times")
}
}
}
func (dExt *XTDriver) PlanNextAction(text string, opts ...option.ActionOption) (*ai.PlanningResult, error) {
if dExt.LLMService == nil {
return nil, errors.New("LLM service is not initialized")
@ -25,15 +58,13 @@ func (dExt *XTDriver) PlanNextAction(text string, opts ...option.ActionOption) (
planningOpts := &ai.PlanningOptions{
UserInstruction: text,
ConversationHistory: []*schema.Message{
{
Role: schema.User,
MultiContent: []schema.ChatMessagePart{
{
Type: "image_url",
ImageURL: &schema.ChatMessageImageURL{
URL: screenShotBase64,
},
Message: &schema.Message{
Role: schema.User,
MultiContent: []schema.ChatMessagePart{
{
Type: schema.ChatMessagePartTypeImageURL,
ImageURL: &schema.ChatMessageImageURL{
URL: screenShotBase64,
},
},
},

View File

@ -8,7 +8,6 @@ import (
"image/gif"
"image/jpeg"
"image/png"
_ "image/png"
"os"
"path/filepath"
"strings"