refactor: ui-tars planner

This commit is contained in:
lilong.129 2025-03-19 17:48:13 +08:00
parent 6c74727c44
commit 7da305f577
8 changed files with 201 additions and 191 deletions

View File

@ -1 +1 @@
v5.0.0-beta-2503190009
v5.0.0-beta-2503191748

View File

@ -3,6 +3,7 @@ package planner
import (
"encoding/json"
"fmt"
"net/http"
"os"
"strconv"
"time"
@ -89,6 +90,20 @@ func GetEnvConfigInInt(key string, defaultValue int) int {
return intValue
}
// CustomTransport is an http.RoundTripper that sets a fixed group of headers
// on every outgoing request before handing it to the wrapped Transport.
type CustomTransport struct {
	// Transport performs the actual round trip. When nil, http.DefaultTransport is used.
	Transport http.RoundTripper
	// Headers are set on each request, replacing any existing value for the same key.
	Headers map[string]string
}

// RoundTrip executes a single HTTP transaction with the custom headers applied.
//
// The http.RoundTripper contract says an implementation must not modify the
// caller's request, so the headers are written to a clone of req; the original
// request (which the http.Client may reuse for redirects or retries) is left
// untouched. It returns whatever the wrapped transport returns.
func (c *CustomTransport) RoundTrip(req *http.Request) (*http.Response, error) {
	next := c.Transport
	if next == nil {
		// A zero-value CustomTransport should still work instead of panicking.
		next = http.DefaultTransport
	}

	// Nothing to inject: skip the clone and forward the request as-is.
	if len(c.Headers) == 0 {
		return next.RoundTrip(req)
	}

	out := req.Clone(req.Context())
	if out.Header == nil {
		// Clone keeps a nil header map nil; Set on it would panic.
		out.Header = make(http.Header, len(c.Headers))
	}
	for key, value := range c.Headers {
		out.Header.Set(key, value)
	}
	return next.RoundTrip(out)
}
// GetModelConfig get OpenAI config
func GetModelConfig() (*openai.ChatModelConfig, error) {
envConfig := &OpenAIInitConfig{
@ -104,7 +119,13 @@ func GetModelConfig() (*openai.ChatModelConfig, error) {
}
config := &openai.ChatModelConfig{
HTTPClient: &http.Client{
Timeout: defaultTimeout,
Transport: &CustomTransport{
Transport: http.DefaultTransport,
Headers: envConfig.Headers,
},
},
}
if baseURL := GetEnvConfig(EnvOpenAIBaseURL); baseURL != "" {

View File

@ -10,40 +10,33 @@ import (
)
// NewActionParser creates a new ActionParser instance
func NewActionParser(prediction string, factor float64) *ActionParser {
func NewActionParser(factor float64) *ActionParser {
return &ActionParser{
Prediction: prediction,
Factor: factor,
}
}
// ActionParser parses VLM responses and converts them to structured actions
type ActionParser struct {
Prediction string
Factor float64
Factor float64 // TODO
}
// Parse parses the prediction text and extracts actions
func (p *ActionParser) Parse(predictionText string) ([]ParsedAction, error) {
// try parsing JSON format
// try parsing JSON format, from VLM like GPT-4o
var jsonActions []ParsedAction
jsonActions, jsonErr := p.parseJSON(predictionText)
if jsonErr == nil && len(jsonActions) > 0 {
if jsonErr == nil {
return jsonActions, nil
}
// if JSON parsing fails, try parsing Thought/Action format
// json parsing failed, try parsing Thought/Action format, from VLM like UI-TARS
thoughtActions, thoughtErr := p.parseThoughtAction(predictionText)
if thoughtErr == nil && len(thoughtActions) > 0 {
if thoughtErr == nil {
return thoughtActions, nil
}
// both parsing methods failed
if jsonErr != nil && thoughtErr != nil {
return nil, fmt.Errorf("failed to parse VLM response: %v; %v", jsonErr, thoughtErr)
}
return nil, fmt.Errorf("no actions returned from VLM")
return nil, fmt.Errorf("no valid actions returned from VLM, jsonErr: %v, thoughtErr: %v", jsonErr, thoughtErr)
}
// parseJSON tries to parse the response as JSON format
@ -92,7 +85,7 @@ func (p *ActionParser) parseThoughtAction(predictionText string) ([]ParsedAction
thought = strings.TrimSpace(thoughtMatch[1])
}
// extract Action part
// extract Action part, e.g. "click(start_box='(552,454)')"
actionMatch := actionRegex.FindStringSubmatch(predictionText)
if len(actionMatch) < 2 {
return nil, fmt.Errorf("no action found in the response")
@ -125,6 +118,7 @@ func (p *ActionParser) parseActionText(actionText, thought string) ([]ParsedActi
"call_user": regexp.MustCompile(`call_user\(\)`),
}
parsedActions := make([]ParsedAction, 0)
for actionType, regex := range actionRegexes {
matches := regex.FindStringSubmatch(actionText)
if len(matches) == 0 {
@ -183,10 +177,13 @@ func (p *ActionParser) parseActionText(actionText, thought string) ([]ParsedActi
// 这些动作没有额外参数
}
return []ParsedAction{action}, nil
parsedActions = append(parsedActions, action)
}
return nil, fmt.Errorf("unknown action format: %s", actionText)
if len(parsedActions) == 0 {
return nil, fmt.Errorf("no valid actions returned from VLM")
}
return parsedActions, nil
}
// normalizeAction normalizes the coordinates in the action
@ -215,16 +212,14 @@ func (p *ActionParser) normalizeAction(action *ParsedAction) error {
}
// normalizeCoordinates normalizes the coordinates based on the factor
func (p *ActionParser) normalizeCoordinates(coordStr string) (string, error) {
var coords []float64
func (p *ActionParser) normalizeCoordinates(coordStr string) (coords []float64, err error) {
// check empty string
if coordStr == "" {
return "", fmt.Errorf("empty coordinate string")
return nil, fmt.Errorf("empty coordinate string")
}
if !strings.Contains(coordStr, ",") {
return "", fmt.Errorf("invalid coordinate string: %s", coordStr)
return nil, fmt.Errorf("invalid coordinate string: %s", coordStr)
}
// remove possible brackets and split coordinates
@ -236,15 +231,9 @@ func (p *ActionParser) normalizeCoordinates(coordStr string) (string, error) {
jsonStr = "[" + coordStr + "]"
}
err := json.Unmarshal([]byte(jsonStr), &coords)
err = json.Unmarshal([]byte(jsonStr), &coords)
if err != nil {
return "", fmt.Errorf("failed to parse coordinate string: %w", err)
return nil, fmt.Errorf("failed to parse coordinate string: %w", err)
}
normalized, err := json.Marshal(coords)
if err != nil {
return "", fmt.Errorf("failed to marshal normalized coordinates: %w", err)
}
return string(normalized), nil
return coords, nil
}

View File

@ -20,38 +20,11 @@ import (
// Error types
var (
ErrInvalidInput = fmt.Errorf("invalid input parameters")
ErrEmptyInstruction = fmt.Errorf("user instruction is empty")
ErrNoConversationHistory = fmt.Errorf("conversation history is empty")
ErrInvalidImageData = fmt.Errorf("invalid image data")
)
const uiTarsPlanningPrompt = `
You are a GUI agent. You are given a task and your action history, with screenshots. You need to perform the next action to complete the task.
## Output Format
Thought: ...
Action: ...
## Action Space
click(start_box='[x1, y1, x2, y2]')
left_double(start_box='[x1, y1, x2, y2]')
right_single(start_box='[x1, y1, x2, y2]')
drag(start_box='[x1, y1, x2, y2]', end_box='[x3, y3, x4, y4]')
hotkey(key='')
type(content='') #If you want to submit your input, use "\n" at the end of content.
scroll(start_box='[x1, y1, x2, y2]', direction='down or up or right or left')
wait() #Sleep for 5s and take a screenshot to check for any changes.
finished()
call_user() # Submit the task and call the user when the task is unsolvable, or when you need the user's help.
## Note
- Use Chinese in Thought part.
- Write a small plan and finally summarize your next action (with its target element) in one sentence in Thought part.
## User Instruction
`
func NewPlanner(ctx context.Context) (*Planner, error) {
config, err := GetModelConfig()
if err != nil {
@ -61,36 +34,39 @@ func NewPlanner(ctx context.Context) (*Planner, error) {
if err != nil {
return nil, fmt.Errorf("failed to initialize OpenAI model: %w", err)
}
parser := NewActionParser(1000)
return &Planner{
ctx: ctx,
model: model,
parser: parser,
}, nil
}
// Planner bundles the state used for VLM-based UI planning. All three fields
// are populated by NewPlanner.
type Planner struct {
// ctx is the context passed to NewPlanner.
ctx context.Context
// model is the OpenAI chat model initialized from GetModelConfig.
model *openai.ChatModel
// parser turns the model's text output into parsed actions; NewPlanner
// creates it with a factor of 1000.
parser *ActionParser
}
// Start performs UI planning using Vision Language Model
func (p *Planner) Start(opts PlanningOptions) (*PlanningResult, error) {
func (p *Planner) Start(opts *PlanningOptions) (*PlanningResult, error) {
log.Info().Str("user_instruction", opts.UserInstruction).Msg("start VLM planning")
// 1. validate input parameters
// validate input parameters
if err := validateInput(opts); err != nil {
return nil, errors.Wrap(err, "validate input parameters failed")
}
// 2. call VLM service
// call VLM service
resp, err := p.callVLMService(opts)
if err != nil {
return nil, errors.Wrap(err, "call VLM service failed")
}
// 3. process response
result, err := processVLMResponse(resp)
// parse result
result, err := p.parseResult(resp)
if err != nil {
return nil, errors.Wrap(err, "process VLM response failed")
return nil, errors.Wrap(err, "parse result failed")
}
log.Info().
@ -100,7 +76,7 @@ func (p *Planner) Start(opts PlanningOptions) (*PlanningResult, error) {
return result, nil
}
func validateInput(opts PlanningOptions) error {
func validateInput(opts *PlanningOptions) error {
if opts.UserInstruction == "" {
return ErrEmptyInstruction
}
@ -109,10 +85,6 @@ func validateInput(opts PlanningOptions) error {
return ErrNoConversationHistory
}
if opts.Size.Width <= 0 || opts.Size.Height <= 0 {
return ErrInvalidInput
}
// ensure at least one image URL
hasImageURL := false
for _, msg := range opts.ConversationHistory {
@ -133,14 +105,14 @@ func validateInput(opts PlanningOptions) error {
}
if !hasImageURL {
return ErrInvalidInput
return ErrInvalidImageData
}
return nil
}
// callVLMService makes the actual call to the VLM service
func (p *Planner) callVLMService(opts PlanningOptions) (*VLMResponse, error) {
func (p *Planner) callVLMService(opts *PlanningOptions) (*schema.Message, error) {
log.Info().Msg("calling VLM service...")
// prepare prompt
@ -158,87 +130,77 @@ func (p *Planner) callVLMService(opts PlanningOptions) (*VLMResponse, error) {
if err != nil {
return nil, fmt.Errorf("OpenAI API request failed: %w", err)
}
return resp, nil
}
func (p *Planner) parseResult(msg *schema.Message) (*PlanningResult, error) {
// parse response
content := resp.Content
parser := NewActionParser(content, 1000) // 使用与 TypeScript 版本相同的 factor
actions, err := parser.Parse(content)
actions, err := p.parser.Parse(msg.Content)
if err != nil {
return nil, fmt.Errorf("failed to parse actions: %w", err)
}
return &VLMResponse{
Actions: actions,
}, nil
// process response
result, err := processVLMResponse(actions)
if err != nil {
return nil, errors.Wrap(err, "process VLM response failed")
}
return result, nil
}
// processVLMResponse processes the VLM response and converts it to PlanningResult
func processVLMResponse(resp *VLMResponse) (*PlanningResult, error) {
func processVLMResponse(actions []ParsedAction) (*PlanningResult, error) {
log.Info().Msg("processing VLM response...")
if resp.Error != "" {
return nil, fmt.Errorf("VLM error: %s", resp.Error)
}
if len(resp.Actions) == 0 {
if len(actions) == 0 {
return nil, fmt.Errorf("no actions returned from VLM")
}
// 验证和后处理每个动作
for i := range resp.Actions {
// 验证动作类型
switch resp.Actions[i].ActionType {
// validate and post-process each action
for i := range actions {
// validate action type
switch actions[i].ActionType {
case "click", "left_double", "right_single":
validateCoordinateAction(&resp.Actions[i], "startBox")
validateCoordinateAction(&actions[i], "startBox")
case "drag":
validateCoordinateAction(&resp.Actions[i], "startBox")
validateCoordinateAction(&resp.Actions[i], "endBox")
validateCoordinateAction(&actions[i], "startBox")
validateCoordinateAction(&actions[i], "endBox")
case "scroll":
validateCoordinateAction(&resp.Actions[i], "startBox")
validateScrollDirection(&resp.Actions[i])
validateCoordinateAction(&actions[i], "startBox")
validateScrollDirection(&actions[i])
case "type":
validateTypeContent(&resp.Actions[i])
validateTypeContent(&actions[i])
case "hotkey":
validateHotkeyAction(&resp.Actions[i])
validateHotkeyAction(&actions[i])
case "wait", "finished", "call_user":
// 这些动作不需要额外参数
// these actions do not need extra parameters
default:
log.Printf("警告: 未知的动作类型: %s, 将尝试继续处理", resp.Actions[i].ActionType)
log.Printf("warning: unknown action type: %s, will try to continue processing", actions[i].ActionType)
}
}
// 提取动作摘要
actionSummary := extractActionSummary(resp.Actions)
// 将ParsedAction转换为接口类型
var actions []interface{}
for _, action := range resp.Actions {
actionMap := map[string]interface{}{
"actionType": action.ActionType,
"actionInputs": action.ActionInputs,
"thought": action.Thought,
}
actions = append(actions, actionMap)
}
// extract action summary
actionSummary := extractActionSummary(actions)
return &PlanningResult{
Actions: actions,
RealActions: resp.Actions,
ActionSummary: actionSummary,
}, nil
}
// extractActionSummary 从动作中提取摘要
// extractActionSummary extracts the summary from the actions
func extractActionSummary(actions []ParsedAction) string {
if len(actions) == 0 {
return ""
}
// 优先使用第一个动作的Thought作为摘要
// use the Thought of the first action as summary
if actions[0].Thought != "" {
return actions[0].Thought
}
// 如果没有Thought则根据动作类型生成摘要
// if no Thought, generate summary from action type
action := actions[0]
switch action.ActionType {
case "click":
@ -274,28 +236,21 @@ func extractActionSummary(actions []ParsedAction) string {
// validateCoordinateAction 验证坐标类动作
func validateCoordinateAction(action *ParsedAction, boxField string) {
if box, ok := action.ActionInputs[boxField]; !ok || box == "" {
// 为空或缺失的坐标设置默认值
action.ActionInputs[boxField] = "[0.5, 0.5]"
log.Printf("警告: %s动作缺少%s参数, 已设置默认值", action.ActionType, boxField)
}
// TODO
}
// validateScrollDirection 验证滚动方向
func validateScrollDirection(action *ParsedAction) {
if direction, ok := action.ActionInputs["direction"].(string); !ok || direction == "" {
// 为空或缺失的方向设置默认值
// default to down
action.ActionInputs["direction"] = "down"
log.Printf("警告: scroll动作缺少direction参数, 已设置默认值")
} else {
// 标准化方向
switch strings.ToLower(direction) {
case "up", "down", "left", "right":
// 保持原样
// keep original direction
default:
// 非标准方向设为默认值
action.ActionInputs["direction"] = "down"
log.Printf("警告: 非标准滚动方向: %s, 已设置为down", direction)
log.Warn().Str("direction", direction).Msg("invalid scroll direction, set to default")
}
}
}
@ -303,9 +258,9 @@ func validateScrollDirection(action *ParsedAction) {
// validateTypeContent 验证输入文本内容
func validateTypeContent(action *ParsedAction) {
if content, ok := action.ActionInputs["content"]; !ok || content == "" {
// 为空或缺失的内容设置默认值
// default to empty string
action.ActionInputs["content"] = ""
log.Printf("警告: type动作缺少content参数, 已设置为空字符串")
log.Warn().Msg("type action missing content parameter, set to default")
}
}

View File

@ -2,7 +2,7 @@ package planner
import (
"context"
"encoding/json"
"os"
"testing"
"github.com/cloudwego/eino/schema"
@ -14,7 +14,6 @@ func TestVLMPlanning(t *testing.T) {
err := loadEnv("testdata/.env")
require.NoError(t, err)
// imageBase64, err := loadImage("testdata/popup_risk_warning.png")
imageBase64, err := loadImage("testdata/llk_1.png")
require.NoError(t, err)
@ -29,12 +28,12 @@ func TestVLMPlanning(t *testing.T) {
5. 得分机制: 每成功连接并消除一对图案玩家会获得相应的分数完成游戏后根据剩余时间和消除效率计算总分
6. 关卡设计: 游戏可能包含多个关卡随着关卡的推进图案的复杂度和数量会增加`
userInstruction += "\n\n请基于以上游戏规则给出下一步可点击的两个图标坐标"
userInstruction += "\n\n请基于以上游戏规则请先点击第一个图标"
planner, err := NewPlanner(context.Background())
require.NoError(t, err)
opts := PlanningOptions{
opts := &PlanningOptions{
UserInstruction: userInstruction,
ConversationHistory: []*schema.Message{
{
@ -49,10 +48,6 @@ func TestVLMPlanning(t *testing.T) {
},
},
},
Size: Size{
Width: 1920,
Height: 1080,
},
}
// 执行规划
@ -61,10 +56,10 @@ func TestVLMPlanning(t *testing.T) {
// 验证结果
require.NoError(t, err)
require.NotNil(t, result)
require.NotEmpty(t, result.RealActions)
require.NotEmpty(t, result.Actions)
// 验证动作
action := result.RealActions[0]
action := result.Actions[0]
assert.NotEmpty(t, action.ActionType)
assert.NotEmpty(t, action.Thought)
@ -75,15 +70,13 @@ func TestVLMPlanning(t *testing.T) {
assert.NotEmpty(t, action.ActionInputs["startBox"])
// 验证坐标格式
var coords []float64
err = json.Unmarshal([]byte(action.ActionInputs["startBox"].(string)), &coords)
require.NoError(t, err)
coords, ok := action.ActionInputs["startBox"].([]float64)
require.True(t, ok)
require.True(t, len(coords) >= 2) // 至少有 x, y 坐标
// 验证坐标范围
for _, coord := range coords {
assert.GreaterOrEqual(t, coord, float64(0))
assert.LessOrEqual(t, coord, float64(1920)) // 最大屏幕宽度
}
case "type":
@ -102,18 +95,61 @@ func TestVLMPlanning(t *testing.T) {
}
}
// TestXHSPlanning runs the planner end to end on the testdata/xhs-feed.jpeg
// screenshot and checks that at least one action with a non-empty type and
// thought comes back. It loads testdata/.env first; presumably that file
// carries the model credentials and a reachable model service is required —
// confirm before running in CI.
func TestXHSPlanning(t *testing.T) {
err := loadEnv("testdata/.env")
require.NoError(t, err)
imageBase64, err := loadImage("testdata/xhs-feed.jpeg")
require.NoError(t, err)
// instruction (Chinese): tap the avatar of the second post's author
userInstruction := `点击第二个帖子的作者头像`
planner, err := NewPlanner(context.Background())
require.NoError(t, err)
// the conversation history holds a single user message carrying only the screenshot
opts := &PlanningOptions{
UserInstruction: userInstruction,
ConversationHistory: []*schema.Message{
{
Role: schema.User,
MultiContent: []schema.ChatMessagePart{
{
Type: "image_url",
ImageURL: &schema.ChatMessageImageURL{
URL: imageBase64,
},
},
},
},
},
}
// run planning
result, err := planner.Start(opts)
// verify the result
require.NoError(t, err)
require.NotNil(t, result)
require.NotEmpty(t, result.Actions)
// verify the first action
action := result.Actions[0]
assert.NotEmpty(t, action.ActionType)
assert.NotEmpty(t, action.Thought)
}
func TestValidateInput(t *testing.T) {
imageBase64, err := loadImage("testdata/popup_risk_warning.png")
require.NoError(t, err)
tests := []struct {
name string
opts PlanningOptions
opts *PlanningOptions
wantErr error
}{
{
name: "valid input",
opts: PlanningOptions{
opts: &PlanningOptions{
UserInstruction: "点击继续使用按钮",
ConversationHistory: []*schema.Message{
{
@ -128,13 +164,12 @@ func TestValidateInput(t *testing.T) {
},
},
},
Size: Size{Width: 100, Height: 100},
},
wantErr: nil,
},
{
name: "empty instruction",
opts: PlanningOptions{
opts: &PlanningOptions{
UserInstruction: "",
ConversationHistory: []*schema.Message{
{
@ -142,32 +177,29 @@ func TestValidateInput(t *testing.T) {
Content: "",
},
},
Size: Size{Width: 100, Height: 100},
},
wantErr: ErrEmptyInstruction,
},
{
name: "empty conversation history",
opts: PlanningOptions{
opts: &PlanningOptions{
UserInstruction: "点击立即卸载按钮",
ConversationHistory: []*schema.Message{},
Size: Size{Width: 100, Height: 100},
},
wantErr: ErrNoConversationHistory,
},
{
name: "invalid size",
opts: PlanningOptions{
UserInstruction: "勾选不再提示选项",
name: "invalid image data",
opts: &PlanningOptions{
UserInstruction: "点击继续使用按钮",
ConversationHistory: []*schema.Message{
{
Role: schema.User,
Content: "",
Content: "no image",
},
},
Size: Size{Width: 0, Height: 0},
},
wantErr: ErrInvalidInput,
wantErr: ErrInvalidImageData,
},
}
@ -176,6 +208,7 @@ func TestValidateInput(t *testing.T) {
err := validateInput(tt.opts)
if tt.wantErr != nil {
assert.Error(t, err)
assert.Equal(t, tt.wantErr, err)
} else {
assert.NoError(t, err)
}
@ -186,40 +219,32 @@ func TestValidateInput(t *testing.T) {
func TestProcessVLMResponse(t *testing.T) {
tests := []struct {
name string
resp *VLMResponse
actions []ParsedAction
wantErr bool
}{
{
name: "valid response",
resp: &VLMResponse{
Actions: []ParsedAction{
actions: []ParsedAction{
{
ActionType: "click",
ActionInputs: map[string]interface{}{
"startBox": "[0.5, 0.5]",
},
"startBox": []float64{0.5, 0.5},
},
Thought: "点击中心位置",
},
},
wantErr: false,
},
{
name: "error response",
resp: &VLMResponse{
Error: "test error",
},
wantErr: true,
},
{
name: "empty actions",
resp: &VLMResponse{},
actions: []ParsedAction{},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := processVLMResponse(tt.resp)
result, err := processVLMResponse(tt.actions)
if tt.wantErr {
assert.Error(t, err)
assert.Nil(t, result)
@ -228,7 +253,7 @@ func TestProcessVLMResponse(t *testing.T) {
assert.NoError(t, err)
assert.NotNil(t, result)
assert.Equal(t, tt.resp.Actions, result.RealActions)
assert.Equal(t, tt.actions, result.Actions)
})
}
}
@ -237,7 +262,6 @@ func TestSavePositionImg(t *testing.T) {
imageBase64, err := loadImage("testdata/popup_risk_warning.png")
require.NoError(t, err)
tempFile := t.TempDir() + "/test.png"
params := struct {
InputImgBase64 string
Rect struct {
@ -254,10 +278,12 @@ func TestSavePositionImg(t *testing.T) {
X: 100,
Y: 100,
},
OutputPath: tempFile,
OutputPath: "testdata/output.png",
}
err = SavePositionImg(params)
assert.NoError(t, err)
// TODO: Add more assertions when SavePositionImg is implemented
require.NoError(t, err)
// cleanup
defer os.Remove(params.OutputPath)
}

27
planner/prompt-ui-tars.go Normal file
View File

@ -0,0 +1,27 @@
package planner
// uiTarsPlanningPrompt is the instruction template for a GUI-agent model. It
// fixes the "Thought: ... / Action: ..." output format, enumerates the allowed
// actions (click, left_double, right_single, drag, hotkey, type, scroll, wait,
// finished, call_user), and asks for the Thought part in Chinese.
// The text ends with a "## User Instruction" heading, presumably so the
// caller can append the user's instruction after it — confirm at the call site.
const uiTarsPlanningPrompt = `
You are a GUI agent. You are given a task and your action history, with screenshots. You need to perform the next action to complete the task.
## Output Format
Thought: ...
Action: ...
## Action Space
click(start_box='[x1, y1, x2, y2]')
left_double(start_box='[x1, y1, x2, y2]')
right_single(start_box='[x1, y1, x2, y2]')
drag(start_box='[x1, y1, x2, y2]', end_box='[x3, y3, x4, y4]')
hotkey(key='')
type(content='') #If you want to submit your input, use "\n" at the end of content.
scroll(start_box='[x1, y1, x2, y2]', direction='down or up or right or left')
wait() #Sleep for 5s and take a screenshot to check for any changes.
finished()
call_user() # Submit the task and call the user when the task is unsolvable, or when you need the user's help.
## Note
- Use Chinese in Thought part.
- Write a small plan and finally summarize your next action (with its target element) in one sentence in Thought part.
## User Instruction
`

BIN
planner/testdata/xhs-feed.jpeg vendored Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 649 KiB

View File

@ -8,20 +8,12 @@ import (
type PlanningOptions struct {
UserInstruction string `json:"user_instruction"`
ConversationHistory []*schema.Message `json:"conversation_history"`
Size Size `json:"size"`
}
// Size represents the dimensions of a screen
type Size struct {
Width int `json:"width"`
Height int `json:"height"`
}
// PlanningResult represents the result of planning
type PlanningResult struct {
Actions []interface{} `json:"actions"`
RealActions []ParsedAction `json:"real_actions"`
ActionSummary string `json:"action_summary"`
Actions []ParsedAction `json:"actions"`
ActionSummary string `json:"summary"`
}
// VLMResponse represents the response from the Vision Language Model