feat: appendConversationHistory for ai planner
This commit is contained in:
parent
bbc05513f9
commit
8a3b6b5c4c
|
@ -1 +1 @@
|
||||||
v5.0.0-beta-2503202053
|
v5.0.0-beta-2503220006
|
||||||
|
|
|
@ -21,9 +21,9 @@ func (s openaiLLMService) Call(opts *PlanningOptions) (*PlanningResult, error) {
|
||||||
|
|
||||||
// PlanningOptions represents the input options for planning
|
// PlanningOptions represents the input options for planning
|
||||||
type PlanningOptions struct {
|
type PlanningOptions struct {
|
||||||
UserInstruction string `json:"user_instruction"`
|
UserInstruction string `json:"user_instruction"` // append to system prompt
|
||||||
ConversationHistory []*schema.Message `json:"conversation_history"`
|
Message *schema.Message `json:"message"`
|
||||||
Size types.Size `json:"size"`
|
Size types.Size `json:"size"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// PlanningResult represents the result of planning
|
// PlanningResult represents the result of planning
|
||||||
|
|
|
@ -47,25 +47,41 @@ func NewPlanner(ctx context.Context) (*Planner, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
type Planner struct {
|
type Planner struct {
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
model model.ChatModel
|
model model.ChatModel
|
||||||
parser *ActionParser
|
parser *ActionParser
|
||||||
|
history []*schema.Message // conversation history
|
||||||
}
|
}
|
||||||
|
|
||||||
// Call performs UI planning using Vision Language Model
|
// Call performs UI planning using Vision Language Model
|
||||||
func (p *Planner) Call(opts *PlanningOptions) (*PlanningResult, error) {
|
func (p *Planner) Call(opts *PlanningOptions) (*PlanningResult, error) {
|
||||||
log.Info().Str("user_instruction", opts.UserInstruction).Msg("start VLM planning")
|
|
||||||
|
|
||||||
// validate input parameters
|
// validate input parameters
|
||||||
if err := validateInput(opts); err != nil {
|
if err := validateInput(opts); err != nil {
|
||||||
return nil, errors.Wrap(err, "validate input parameters failed")
|
return nil, errors.Wrap(err, "validate input parameters failed")
|
||||||
}
|
}
|
||||||
|
|
||||||
// call VLM service
|
// prepare prompt
|
||||||
resp, err := p.callVLMService(opts)
|
if len(p.history) == 0 {
|
||||||
if err != nil {
|
// add system message
|
||||||
return nil, errors.Wrap(err, "call VLM service failed")
|
systemPrompt := uiTarsPlanningPrompt + opts.UserInstruction
|
||||||
|
p.history = []*schema.Message{
|
||||||
|
{
|
||||||
|
Role: schema.System,
|
||||||
|
Content: systemPrompt,
|
||||||
|
},
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
// append user image message
|
||||||
|
p.appendConversationHistory(opts.Message)
|
||||||
|
|
||||||
|
// call model service, generate response
|
||||||
|
logRequest(p.history)
|
||||||
|
log.Info().Msg("calling model service...")
|
||||||
|
resp, err := p.model.Generate(p.ctx, p.history)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("request model service failed: %w", err)
|
||||||
|
}
|
||||||
|
logResponse(resp)
|
||||||
|
|
||||||
// parse result
|
// parse result
|
||||||
result, err := p.parseResult(resp, opts.Size)
|
result, err := p.parseResult(resp, opts.Size)
|
||||||
|
@ -73,10 +89,12 @@ func (p *Planner) Call(opts *PlanningOptions) (*PlanningResult, error) {
|
||||||
return nil, errors.Wrap(err, "parse result failed")
|
return nil, errors.Wrap(err, "parse result failed")
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Info().
|
// append assistant message
|
||||||
Interface("summary", result.ActionSummary).
|
p.appendConversationHistory(&schema.Message{
|
||||||
Interface("actions", result.NextActions).
|
Role: schema.Assistant,
|
||||||
Msg("get VLM planning result")
|
Content: result.ActionSummary,
|
||||||
|
})
|
||||||
|
|
||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -85,57 +103,107 @@ func validateInput(opts *PlanningOptions) error {
|
||||||
return ErrEmptyInstruction
|
return ErrEmptyInstruction
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(opts.ConversationHistory) == 0 {
|
if opts.Message == nil {
|
||||||
return ErrNoConversationHistory
|
return ErrNoConversationHistory
|
||||||
}
|
}
|
||||||
|
|
||||||
// ensure at least one image URL
|
if opts.Message.Role == schema.User {
|
||||||
hasImageURL := false
|
// check MultiContent
|
||||||
for _, msg := range opts.ConversationHistory {
|
if len(opts.Message.MultiContent) > 0 {
|
||||||
if msg.Role == "user" {
|
for _, content := range opts.Message.MultiContent {
|
||||||
// check MultiContent
|
if content.Type == schema.ChatMessagePartTypeImageURL && content.ImageURL == nil {
|
||||||
if len(msg.MultiContent) > 0 {
|
return ErrInvalidImageData
|
||||||
for _, content := range msg.MultiContent {
|
|
||||||
if content.Type == "image_url" && content.ImageURL != nil {
|
|
||||||
hasImageURL = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if hasImageURL {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if !hasImageURL {
|
|
||||||
return ErrInvalidImageData
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// callVLMService makes the actual call to the VLM service
|
func logRequest(messages []*schema.Message) {
|
||||||
func (p *Planner) callVLMService(opts *PlanningOptions) (*schema.Message, error) {
|
msgs := make([]*schema.Message, 0, len(messages))
|
||||||
log.Info().Msg("calling VLM service...")
|
for _, message := range messages {
|
||||||
|
msg := &schema.Message{
|
||||||
// prepare prompt
|
Role: message.Role,
|
||||||
systemPrompt := uiTarsPlanningPrompt + opts.UserInstruction
|
}
|
||||||
messages := []*schema.Message{
|
if message.Content != "" {
|
||||||
{
|
msg.Content = message.Content
|
||||||
Role: schema.System,
|
} else if len(message.MultiContent) > 0 {
|
||||||
Content: systemPrompt,
|
for _, mc := range message.MultiContent {
|
||||||
},
|
switch mc.Type {
|
||||||
|
case schema.ChatMessagePartTypeImageURL:
|
||||||
|
// Create a copy of the ImageURL to avoid modifying the original message
|
||||||
|
imageURLCopy := *mc.ImageURL
|
||||||
|
if strings.HasPrefix(imageURLCopy.URL, "data:image/") {
|
||||||
|
imageURLCopy.URL = "<data:image/base64...>"
|
||||||
|
}
|
||||||
|
msg.MultiContent = append(msg.MultiContent, schema.ChatMessagePart{
|
||||||
|
Type: mc.Type,
|
||||||
|
ImageURL: &imageURLCopy,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
msgs = append(msgs, msg)
|
||||||
}
|
}
|
||||||
messages = append(messages, opts.ConversationHistory...)
|
log.Debug().Interface("messages", msgs).Msg("log request messages")
|
||||||
|
}
|
||||||
|
|
||||||
// generate response
|
func logResponse(resp *schema.Message) {
|
||||||
resp, err := p.model.Generate(p.ctx, messages)
|
log.Info().Str("role", string(resp.Role)).
|
||||||
if err != nil {
|
Str("content", resp.Content).Msg("log response message")
|
||||||
return nil, fmt.Errorf("OpenAI API request failed: %w", err)
|
}
|
||||||
|
|
||||||
|
// appendConversationHistory adds a message to the conversation history
|
||||||
|
func (p *Planner) appendConversationHistory(msg *schema.Message) {
|
||||||
|
// for user image message:
|
||||||
|
// - keep at most 4 user image messages
|
||||||
|
// - delete the oldest user image message when the limit is reached
|
||||||
|
if msg.Role == schema.User {
|
||||||
|
// get all existing user messages
|
||||||
|
userImgCount := 0
|
||||||
|
firstUserImgIndex := -1
|
||||||
|
|
||||||
|
// calculate the number of user messages and find the index of the first user message
|
||||||
|
for i, item := range p.history {
|
||||||
|
if item.Role == schema.User {
|
||||||
|
userImgCount++
|
||||||
|
if firstUserImgIndex == -1 {
|
||||||
|
firstUserImgIndex = i
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// if there are already 4 user messages, delete the first one before adding the new message
|
||||||
|
if userImgCount >= 4 && firstUserImgIndex >= 0 {
|
||||||
|
// delete the first user message
|
||||||
|
p.history = append(
|
||||||
|
p.history[:firstUserImgIndex],
|
||||||
|
p.history[firstUserImgIndex+1:]...,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
// add the new user message to the history
|
||||||
|
p.history = append(p.history, msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
// for assistant message:
|
||||||
|
// - keep at most the last 10 assistant messages
|
||||||
|
if msg.Role == schema.Assistant {
|
||||||
|
// add the new assistant message to the history
|
||||||
|
p.history = append(p.history, msg)
|
||||||
|
|
||||||
|
// if there are more than 10 assistant messages, remove the oldest ones
|
||||||
|
assistantMsgCount := 0
|
||||||
|
for i := len(p.history) - 1; i >= 0; i-- {
|
||||||
|
if p.history[i].Role == schema.Assistant {
|
||||||
|
assistantMsgCount++
|
||||||
|
if assistantMsgCount > 10 {
|
||||||
|
p.history = append(p.history[:i], p.history[i+1:]...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
log.Info().Str("content", resp.Content).Msg("get VLM response")
|
|
||||||
return resp, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Planner) parseResult(msg *schema.Message, size types.Size) (*PlanningResult, error) {
|
func (p *Planner) parseResult(msg *schema.Message, size types.Size) (*PlanningResult, error) {
|
||||||
|
@ -151,6 +219,10 @@ func (p *Planner) parseResult(msg *schema.Message, size types.Size) (*PlanningRe
|
||||||
return nil, errors.Wrap(err, "process VLM response failed")
|
return nil, errors.Wrap(err, "process VLM response failed")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
log.Info().
|
||||||
|
Interface("summary", result.ActionSummary).
|
||||||
|
Interface("actions", result.NextActions).
|
||||||
|
Msg("get VLM planning result")
|
||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -12,7 +12,7 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestVLMPlanning(t *testing.T) {
|
func TestVLMPlanning(t *testing.T) {
|
||||||
imageBase64, size, err := loadImage("testdata/llk_3.png")
|
imageBase64, size, err := loadImage("testdata/llk_1.png")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
userInstruction := `连连看是一款经典的益智消除类小游戏,通常以图案或图标为主要元素。以下是连连看的基本规则说明:
|
userInstruction := `连连看是一款经典的益智消除类小游戏,通常以图案或图标为主要元素。以下是连连看的基本规则说明:
|
||||||
|
@ -33,15 +33,13 @@ func TestVLMPlanning(t *testing.T) {
|
||||||
|
|
||||||
opts := &PlanningOptions{
|
opts := &PlanningOptions{
|
||||||
UserInstruction: userInstruction,
|
UserInstruction: userInstruction,
|
||||||
ConversationHistory: []*schema.Message{
|
Message: &schema.Message{
|
||||||
{
|
Role: schema.User,
|
||||||
Role: schema.User,
|
MultiContent: []schema.ChatMessagePart{
|
||||||
MultiContent: []schema.ChatMessagePart{
|
{
|
||||||
{
|
Type: schema.ChatMessagePartTypeImageURL,
|
||||||
Type: "image_url",
|
ImageURL: &schema.ChatMessageImageURL{
|
||||||
ImageURL: &schema.ChatMessageImageURL{
|
URL: imageBase64,
|
||||||
URL: imageBase64,
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
@ -108,15 +106,13 @@ func TestXHSPlanning(t *testing.T) {
|
||||||
|
|
||||||
opts := &PlanningOptions{
|
opts := &PlanningOptions{
|
||||||
UserInstruction: userInstruction,
|
UserInstruction: userInstruction,
|
||||||
ConversationHistory: []*schema.Message{
|
Message: &schema.Message{
|
||||||
{
|
Role: schema.User,
|
||||||
Role: schema.User,
|
MultiContent: []schema.ChatMessagePart{
|
||||||
MultiContent: []schema.ChatMessagePart{
|
{
|
||||||
{
|
Type: schema.ChatMessagePartTypeImageURL,
|
||||||
Type: "image_url",
|
ImageURL: &schema.ChatMessageImageURL{
|
||||||
ImageURL: &schema.ChatMessageImageURL{
|
URL: imageBase64,
|
||||||
URL: imageBase64,
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
@ -182,15 +178,13 @@ func TestValidateInput(t *testing.T) {
|
||||||
name: "valid input",
|
name: "valid input",
|
||||||
opts: &PlanningOptions{
|
opts: &PlanningOptions{
|
||||||
UserInstruction: "点击继续使用按钮",
|
UserInstruction: "点击继续使用按钮",
|
||||||
ConversationHistory: []*schema.Message{
|
Message: &schema.Message{
|
||||||
{
|
Role: schema.User,
|
||||||
Role: schema.User,
|
MultiContent: []schema.ChatMessagePart{
|
||||||
MultiContent: []schema.ChatMessagePart{
|
{
|
||||||
{
|
Type: "image_url",
|
||||||
Type: "image_url",
|
ImageURL: &schema.ChatMessageImageURL{
|
||||||
ImageURL: &schema.ChatMessageImageURL{
|
URL: imageBase64,
|
||||||
URL: imageBase64,
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
@ -203,11 +197,9 @@ func TestValidateInput(t *testing.T) {
|
||||||
name: "empty instruction",
|
name: "empty instruction",
|
||||||
opts: &PlanningOptions{
|
opts: &PlanningOptions{
|
||||||
UserInstruction: "",
|
UserInstruction: "",
|
||||||
ConversationHistory: []*schema.Message{
|
Message: &schema.Message{
|
||||||
{
|
Role: schema.User,
|
||||||
Role: schema.User,
|
Content: "",
|
||||||
Content: "",
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
Size: size,
|
Size: size,
|
||||||
},
|
},
|
||||||
|
@ -216,8 +208,8 @@ func TestValidateInput(t *testing.T) {
|
||||||
{
|
{
|
||||||
name: "empty conversation history",
|
name: "empty conversation history",
|
||||||
opts: &PlanningOptions{
|
opts: &PlanningOptions{
|
||||||
UserInstruction: "点击立即卸载按钮",
|
UserInstruction: "点击立即卸载按钮",
|
||||||
ConversationHistory: []*schema.Message{},
|
Message: &schema.Message{},
|
||||||
},
|
},
|
||||||
wantErr: ErrNoConversationHistory,
|
wantErr: ErrNoConversationHistory,
|
||||||
},
|
},
|
||||||
|
@ -225,11 +217,9 @@ func TestValidateInput(t *testing.T) {
|
||||||
name: "invalid image data",
|
name: "invalid image data",
|
||||||
opts: &PlanningOptions{
|
opts: &PlanningOptions{
|
||||||
UserInstruction: "点击继续使用按钮",
|
UserInstruction: "点击继续使用按钮",
|
||||||
ConversationHistory: []*schema.Message{
|
Message: &schema.Message{
|
||||||
{
|
Role: schema.User,
|
||||||
Role: schema.User,
|
Content: "no image",
|
||||||
Content: "no image",
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
Size: size,
|
Size: size,
|
||||||
},
|
},
|
||||||
|
|
|
@ -6,8 +6,41 @@ import (
|
||||||
"github.com/httprunner/httprunner/v5/uixt/ai"
|
"github.com/httprunner/httprunner/v5/uixt/ai"
|
||||||
"github.com/httprunner/httprunner/v5/uixt/option"
|
"github.com/httprunner/httprunner/v5/uixt/option"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func (dExt *XTDriver) StartToGoal(text string, opts ...option.ActionOption) error {
|
||||||
|
options := option.NewActionOptions(opts...)
|
||||||
|
var attempt int
|
||||||
|
for {
|
||||||
|
attempt++
|
||||||
|
log.Info().Int("attempt", attempt).Msg("planning attempt")
|
||||||
|
|
||||||
|
// plan next action
|
||||||
|
result, err := dExt.PlanNextAction(text, opts...)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "failed to get next action from planner")
|
||||||
|
}
|
||||||
|
|
||||||
|
// do actions
|
||||||
|
for _, action := range result.NextActions {
|
||||||
|
switch action.ActionType {
|
||||||
|
case ai.ActionTypeClick:
|
||||||
|
point := action.ActionInputs["startBox"].([]float64)
|
||||||
|
if err := dExt.TapAbsXY(point[0], point[1], opts...); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
case ai.ActionTypeFinished:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if options.MaxRetryTimes > 1 && attempt >= options.MaxRetryTimes {
|
||||||
|
return errors.New("reached max retry times")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (dExt *XTDriver) PlanNextAction(text string, opts ...option.ActionOption) (*ai.PlanningResult, error) {
|
func (dExt *XTDriver) PlanNextAction(text string, opts ...option.ActionOption) (*ai.PlanningResult, error) {
|
||||||
if dExt.LLMService == nil {
|
if dExt.LLMService == nil {
|
||||||
return nil, errors.New("LLM service is not initialized")
|
return nil, errors.New("LLM service is not initialized")
|
||||||
|
@ -25,15 +58,13 @@ func (dExt *XTDriver) PlanNextAction(text string, opts ...option.ActionOption) (
|
||||||
|
|
||||||
planningOpts := &ai.PlanningOptions{
|
planningOpts := &ai.PlanningOptions{
|
||||||
UserInstruction: text,
|
UserInstruction: text,
|
||||||
ConversationHistory: []*schema.Message{
|
Message: &schema.Message{
|
||||||
{
|
Role: schema.User,
|
||||||
Role: schema.User,
|
MultiContent: []schema.ChatMessagePart{
|
||||||
MultiContent: []schema.ChatMessagePart{
|
{
|
||||||
{
|
Type: schema.ChatMessagePartTypeImageURL,
|
||||||
Type: "image_url",
|
ImageURL: &schema.ChatMessageImageURL{
|
||||||
ImageURL: &schema.ChatMessageImageURL{
|
URL: screenShotBase64,
|
||||||
URL: screenShotBase64,
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
|
@ -8,7 +8,6 @@ import (
|
||||||
"image/gif"
|
"image/gif"
|
||||||
"image/jpeg"
|
"image/jpeg"
|
||||||
"image/png"
|
"image/png"
|
||||||
_ "image/png"
|
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
Loading…
Reference in New Issue