refactor: ui-tars planner
This commit is contained in:
parent
6c74727c44
commit
7da305f577
|
@ -1 +1 @@
|
|||
v5.0.0-beta-2503190009
|
||||
v5.0.0-beta-2503191748
|
||||
|
|
|
@ -3,6 +3,7 @@ package planner
|
|||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"strconv"
|
||||
"time"
|
||||
|
@ -89,6 +90,20 @@ func GetEnvConfigInInt(key string, defaultValue int) int {
|
|||
return intValue
|
||||
}
|
||||
|
||||
// CustomTransport is a custom RoundTripper that adds headers to every request
|
||||
type CustomTransport struct {
|
||||
Transport http.RoundTripper
|
||||
Headers map[string]string
|
||||
}
|
||||
|
||||
// RoundTrip executes a single HTTP transaction and adds custom headers
|
||||
func (c *CustomTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
for key, value := range c.Headers {
|
||||
req.Header.Set(key, value)
|
||||
}
|
||||
return c.Transport.RoundTrip(req)
|
||||
}
|
||||
|
||||
// GetModelConfig get OpenAI config
|
||||
func GetModelConfig() (*openai.ChatModelConfig, error) {
|
||||
envConfig := &OpenAIInitConfig{
|
||||
|
@ -104,7 +119,13 @@ func GetModelConfig() (*openai.ChatModelConfig, error) {
|
|||
}
|
||||
|
||||
config := &openai.ChatModelConfig{
|
||||
Timeout: defaultTimeout,
|
||||
HTTPClient: &http.Client{
|
||||
Timeout: defaultTimeout,
|
||||
Transport: &CustomTransport{
|
||||
Transport: http.DefaultTransport,
|
||||
Headers: envConfig.Headers,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if baseURL := GetEnvConfig(EnvOpenAIBaseURL); baseURL != "" {
|
||||
|
|
|
@ -10,40 +10,33 @@ import (
|
|||
)
|
||||
|
||||
// NewActionParser creates a new ActionParser instance
|
||||
func NewActionParser(prediction string, factor float64) *ActionParser {
|
||||
func NewActionParser(factor float64) *ActionParser {
|
||||
return &ActionParser{
|
||||
Prediction: prediction,
|
||||
Factor: factor,
|
||||
Factor: factor,
|
||||
}
|
||||
}
|
||||
|
||||
// ActionParser parses VLM responses and converts them to structured actions
|
||||
type ActionParser struct {
|
||||
Prediction string
|
||||
Factor float64
|
||||
Factor float64 // TODO
|
||||
}
|
||||
|
||||
// Parse parses the prediction text and extracts actions
|
||||
func (p *ActionParser) Parse(predictionText string) ([]ParsedAction, error) {
|
||||
// try parsing JSON format
|
||||
// try parsing JSON format, from VLM like GPT-4o
|
||||
var jsonActions []ParsedAction
|
||||
jsonActions, jsonErr := p.parseJSON(predictionText)
|
||||
if jsonErr == nil && len(jsonActions) > 0 {
|
||||
if jsonErr == nil {
|
||||
return jsonActions, nil
|
||||
}
|
||||
|
||||
// if JSON parsing fails, try parsing Thought/Action format
|
||||
// json parsing failed, try parsing Thought/Action format, from VLM like UI-TARS
|
||||
thoughtActions, thoughtErr := p.parseThoughtAction(predictionText)
|
||||
if thoughtErr == nil && len(thoughtActions) > 0 {
|
||||
if thoughtErr == nil {
|
||||
return thoughtActions, nil
|
||||
}
|
||||
|
||||
// both parsing methods failed
|
||||
if jsonErr != nil && thoughtErr != nil {
|
||||
return nil, fmt.Errorf("failed to parse VLM response: %v; %v", jsonErr, thoughtErr)
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("no actions returned from VLM")
|
||||
return nil, fmt.Errorf("no valid actions returned from VLM, jsonErr: %v, thoughtErr: %v", jsonErr, thoughtErr)
|
||||
}
|
||||
|
||||
// parseJSON tries to parse the response as JSON format
|
||||
|
@ -92,7 +85,7 @@ func (p *ActionParser) parseThoughtAction(predictionText string) ([]ParsedAction
|
|||
thought = strings.TrimSpace(thoughtMatch[1])
|
||||
}
|
||||
|
||||
// extract Action part
|
||||
// extract Action part, e.g. "click(start_box='(552,454)')"
|
||||
actionMatch := actionRegex.FindStringSubmatch(predictionText)
|
||||
if len(actionMatch) < 2 {
|
||||
return nil, fmt.Errorf("no action found in the response")
|
||||
|
@ -125,6 +118,7 @@ func (p *ActionParser) parseActionText(actionText, thought string) ([]ParsedActi
|
|||
"call_user": regexp.MustCompile(`call_user\(\)`),
|
||||
}
|
||||
|
||||
parsedActions := make([]ParsedAction, 0)
|
||||
for actionType, regex := range actionRegexes {
|
||||
matches := regex.FindStringSubmatch(actionText)
|
||||
if len(matches) == 0 {
|
||||
|
@ -183,10 +177,13 @@ func (p *ActionParser) parseActionText(actionText, thought string) ([]ParsedActi
|
|||
// 这些动作没有额外参数
|
||||
}
|
||||
|
||||
return []ParsedAction{action}, nil
|
||||
parsedActions = append(parsedActions, action)
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("unknown action format: %s", actionText)
|
||||
if len(parsedActions) == 0 {
|
||||
return nil, fmt.Errorf("no valid actions returned from VLM")
|
||||
}
|
||||
return parsedActions, nil
|
||||
}
|
||||
|
||||
// normalizeAction normalizes the coordinates in the action
|
||||
|
@ -215,16 +212,14 @@ func (p *ActionParser) normalizeAction(action *ParsedAction) error {
|
|||
}
|
||||
|
||||
// normalizeCoordinates normalizes the coordinates based on the factor
|
||||
func (p *ActionParser) normalizeCoordinates(coordStr string) (string, error) {
|
||||
var coords []float64
|
||||
|
||||
func (p *ActionParser) normalizeCoordinates(coordStr string) (coords []float64, err error) {
|
||||
// check empty string
|
||||
if coordStr == "" {
|
||||
return "", fmt.Errorf("empty coordinate string")
|
||||
return nil, fmt.Errorf("empty coordinate string")
|
||||
}
|
||||
|
||||
if !strings.Contains(coordStr, ",") {
|
||||
return "", fmt.Errorf("invalid coordinate string: %s", coordStr)
|
||||
return nil, fmt.Errorf("invalid coordinate string: %s", coordStr)
|
||||
}
|
||||
|
||||
// remove possible brackets and split coordinates
|
||||
|
@ -236,15 +231,9 @@ func (p *ActionParser) normalizeCoordinates(coordStr string) (string, error) {
|
|||
jsonStr = "[" + coordStr + "]"
|
||||
}
|
||||
|
||||
err := json.Unmarshal([]byte(jsonStr), &coords)
|
||||
err = json.Unmarshal([]byte(jsonStr), &coords)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to parse coordinate string: %w", err)
|
||||
return nil, fmt.Errorf("failed to parse coordinate string: %w", err)
|
||||
}
|
||||
|
||||
normalized, err := json.Marshal(coords)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to marshal normalized coordinates: %w", err)
|
||||
}
|
||||
|
||||
return string(normalized), nil
|
||||
return coords, nil
|
||||
}
|
||||
|
|
|
@ -20,38 +20,11 @@ import (
|
|||
|
||||
// Error types
|
||||
var (
|
||||
ErrInvalidInput = fmt.Errorf("invalid input parameters")
|
||||
ErrEmptyInstruction = fmt.Errorf("user instruction is empty")
|
||||
ErrNoConversationHistory = fmt.Errorf("conversation history is empty")
|
||||
ErrInvalidImageData = fmt.Errorf("invalid image data")
|
||||
)
|
||||
|
||||
const uiTarsPlanningPrompt = `
|
||||
You are a GUI agent. You are given a task and your action history, with screenshots. You need to perform the next action to complete the task.
|
||||
|
||||
## Output Format
|
||||
Thought: ...
|
||||
Action: ...
|
||||
|
||||
## Action Space
|
||||
click(start_box='[x1, y1, x2, y2]')
|
||||
left_double(start_box='[x1, y1, x2, y2]')
|
||||
right_single(start_box='[x1, y1, x2, y2]')
|
||||
drag(start_box='[x1, y1, x2, y2]', end_box='[x3, y3, x4, y4]')
|
||||
hotkey(key='')
|
||||
type(content='') #If you want to submit your input, use "\n" at the end of content.
|
||||
scroll(start_box='[x1, y1, x2, y2]', direction='down or up or right or left')
|
||||
wait() #Sleep for 5s and take a screenshot to check for any changes.
|
||||
finished()
|
||||
call_user() # Submit the task and call the user when the task is unsolvable, or when you need the user's help.
|
||||
|
||||
## Note
|
||||
- Use Chinese in Thought part.
|
||||
- Write a small plan and finally summarize your next action (with its target element) in one sentence in Thought part.
|
||||
|
||||
## User Instruction
|
||||
`
|
||||
|
||||
func NewPlanner(ctx context.Context) (*Planner, error) {
|
||||
config, err := GetModelConfig()
|
||||
if err != nil {
|
||||
|
@ -61,36 +34,39 @@ func NewPlanner(ctx context.Context) (*Planner, error) {
|
|||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to initialize OpenAI model: %w", err)
|
||||
}
|
||||
parser := NewActionParser(1000)
|
||||
return &Planner{
|
||||
ctx: ctx,
|
||||
model: model,
|
||||
ctx: ctx,
|
||||
model: model,
|
||||
parser: parser,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type Planner struct {
|
||||
ctx context.Context
|
||||
model *openai.ChatModel
|
||||
ctx context.Context
|
||||
model *openai.ChatModel
|
||||
parser *ActionParser
|
||||
}
|
||||
|
||||
// Start performs UI planning using Vision Language Model
|
||||
func (p *Planner) Start(opts PlanningOptions) (*PlanningResult, error) {
|
||||
func (p *Planner) Start(opts *PlanningOptions) (*PlanningResult, error) {
|
||||
log.Info().Str("user_instruction", opts.UserInstruction).Msg("start VLM planning")
|
||||
|
||||
// 1. validate input parameters
|
||||
// validate input parameters
|
||||
if err := validateInput(opts); err != nil {
|
||||
return nil, errors.Wrap(err, "validate input parameters failed")
|
||||
}
|
||||
|
||||
// 2. call VLM service
|
||||
// call VLM service
|
||||
resp, err := p.callVLMService(opts)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "call VLM service failed")
|
||||
}
|
||||
|
||||
// 3. process response
|
||||
result, err := processVLMResponse(resp)
|
||||
// parse result
|
||||
result, err := p.parseResult(resp)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "process VLM response failed")
|
||||
return nil, errors.Wrap(err, "parse result failed")
|
||||
}
|
||||
|
||||
log.Info().
|
||||
|
@ -100,7 +76,7 @@ func (p *Planner) Start(opts PlanningOptions) (*PlanningResult, error) {
|
|||
return result, nil
|
||||
}
|
||||
|
||||
func validateInput(opts PlanningOptions) error {
|
||||
func validateInput(opts *PlanningOptions) error {
|
||||
if opts.UserInstruction == "" {
|
||||
return ErrEmptyInstruction
|
||||
}
|
||||
|
@ -109,10 +85,6 @@ func validateInput(opts PlanningOptions) error {
|
|||
return ErrNoConversationHistory
|
||||
}
|
||||
|
||||
if opts.Size.Width <= 0 || opts.Size.Height <= 0 {
|
||||
return ErrInvalidInput
|
||||
}
|
||||
|
||||
// ensure at least one image URL
|
||||
hasImageURL := false
|
||||
for _, msg := range opts.ConversationHistory {
|
||||
|
@ -133,14 +105,14 @@ func validateInput(opts PlanningOptions) error {
|
|||
}
|
||||
|
||||
if !hasImageURL {
|
||||
return ErrInvalidInput
|
||||
return ErrInvalidImageData
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// callVLMService makes the actual call to the VLM service
|
||||
func (p *Planner) callVLMService(opts PlanningOptions) (*VLMResponse, error) {
|
||||
func (p *Planner) callVLMService(opts *PlanningOptions) (*schema.Message, error) {
|
||||
log.Info().Msg("calling VLM service...")
|
||||
|
||||
// prepare prompt
|
||||
|
@ -158,87 +130,77 @@ func (p *Planner) callVLMService(opts PlanningOptions) (*VLMResponse, error) {
|
|||
if err != nil {
|
||||
return nil, fmt.Errorf("OpenAI API request failed: %w", err)
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (p *Planner) parseResult(msg *schema.Message) (*PlanningResult, error) {
|
||||
// parse response
|
||||
content := resp.Content
|
||||
parser := NewActionParser(content, 1000) // 使用与 TypeScript 版本相同的 factor
|
||||
actions, err := parser.Parse(content)
|
||||
actions, err := p.parser.Parse(msg.Content)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse actions: %w", err)
|
||||
}
|
||||
|
||||
return &VLMResponse{
|
||||
Actions: actions,
|
||||
}, nil
|
||||
// process response
|
||||
result, err := processVLMResponse(actions)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "process VLM response failed")
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// processVLMResponse processes the VLM response and converts it to PlanningResult
|
||||
func processVLMResponse(resp *VLMResponse) (*PlanningResult, error) {
|
||||
func processVLMResponse(actions []ParsedAction) (*PlanningResult, error) {
|
||||
log.Info().Msg("processing VLM response...")
|
||||
if resp.Error != "" {
|
||||
return nil, fmt.Errorf("VLM error: %s", resp.Error)
|
||||
}
|
||||
|
||||
if len(resp.Actions) == 0 {
|
||||
if len(actions) == 0 {
|
||||
return nil, fmt.Errorf("no actions returned from VLM")
|
||||
}
|
||||
|
||||
// 验证和后处理每个动作
|
||||
for i := range resp.Actions {
|
||||
// 验证动作类型
|
||||
switch resp.Actions[i].ActionType {
|
||||
// validate and post-process each action
|
||||
for i := range actions {
|
||||
// validate action type
|
||||
switch actions[i].ActionType {
|
||||
case "click", "left_double", "right_single":
|
||||
validateCoordinateAction(&resp.Actions[i], "startBox")
|
||||
validateCoordinateAction(&actions[i], "startBox")
|
||||
case "drag":
|
||||
validateCoordinateAction(&resp.Actions[i], "startBox")
|
||||
validateCoordinateAction(&resp.Actions[i], "endBox")
|
||||
validateCoordinateAction(&actions[i], "startBox")
|
||||
validateCoordinateAction(&actions[i], "endBox")
|
||||
case "scroll":
|
||||
validateCoordinateAction(&resp.Actions[i], "startBox")
|
||||
validateScrollDirection(&resp.Actions[i])
|
||||
validateCoordinateAction(&actions[i], "startBox")
|
||||
validateScrollDirection(&actions[i])
|
||||
case "type":
|
||||
validateTypeContent(&resp.Actions[i])
|
||||
validateTypeContent(&actions[i])
|
||||
case "hotkey":
|
||||
validateHotkeyAction(&resp.Actions[i])
|
||||
validateHotkeyAction(&actions[i])
|
||||
case "wait", "finished", "call_user":
|
||||
// 这些动作不需要额外参数
|
||||
// these actions do not need extra parameters
|
||||
default:
|
||||
log.Printf("警告: 未知的动作类型: %s, 将尝试继续处理", resp.Actions[i].ActionType)
|
||||
log.Printf("warning: unknown action type: %s, will try to continue processing", actions[i].ActionType)
|
||||
}
|
||||
}
|
||||
|
||||
// 提取动作摘要
|
||||
actionSummary := extractActionSummary(resp.Actions)
|
||||
|
||||
// 将ParsedAction转换为接口类型
|
||||
var actions []interface{}
|
||||
for _, action := range resp.Actions {
|
||||
actionMap := map[string]interface{}{
|
||||
"actionType": action.ActionType,
|
||||
"actionInputs": action.ActionInputs,
|
||||
"thought": action.Thought,
|
||||
}
|
||||
actions = append(actions, actionMap)
|
||||
}
|
||||
// extract action summary
|
||||
actionSummary := extractActionSummary(actions)
|
||||
|
||||
return &PlanningResult{
|
||||
Actions: actions,
|
||||
RealActions: resp.Actions,
|
||||
ActionSummary: actionSummary,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// extractActionSummary 从动作中提取摘要
|
||||
// extractActionSummary extracts the summary from the actions
|
||||
func extractActionSummary(actions []ParsedAction) string {
|
||||
if len(actions) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
// 优先使用第一个动作的Thought作为摘要
|
||||
// use the Thought of the first action as summary
|
||||
if actions[0].Thought != "" {
|
||||
return actions[0].Thought
|
||||
}
|
||||
|
||||
// 如果没有Thought,则根据动作类型生成摘要
|
||||
// if no Thought, generate summary from action type
|
||||
action := actions[0]
|
||||
switch action.ActionType {
|
||||
case "click":
|
||||
|
@ -274,28 +236,21 @@ func extractActionSummary(actions []ParsedAction) string {
|
|||
|
||||
// validateCoordinateAction 验证坐标类动作
|
||||
func validateCoordinateAction(action *ParsedAction, boxField string) {
|
||||
if box, ok := action.ActionInputs[boxField]; !ok || box == "" {
|
||||
// 为空或缺失的坐标设置默认值
|
||||
action.ActionInputs[boxField] = "[0.5, 0.5]"
|
||||
log.Printf("警告: %s动作缺少%s参数, 已设置默认值", action.ActionType, boxField)
|
||||
}
|
||||
// TODO
|
||||
}
|
||||
|
||||
// validateScrollDirection 验证滚动方向
|
||||
func validateScrollDirection(action *ParsedAction) {
|
||||
if direction, ok := action.ActionInputs["direction"].(string); !ok || direction == "" {
|
||||
// 为空或缺失的方向设置默认值
|
||||
// default to down
|
||||
action.ActionInputs["direction"] = "down"
|
||||
log.Printf("警告: scroll动作缺少direction参数, 已设置默认值")
|
||||
} else {
|
||||
// 标准化方向
|
||||
switch strings.ToLower(direction) {
|
||||
case "up", "down", "left", "right":
|
||||
// 保持原样
|
||||
// keep original direction
|
||||
default:
|
||||
// 非标准方向设为默认值
|
||||
action.ActionInputs["direction"] = "down"
|
||||
log.Printf("警告: 非标准滚动方向: %s, 已设置为down", direction)
|
||||
log.Warn().Str("direction", direction).Msg("invalid scroll direction, set to default")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -303,9 +258,9 @@ func validateScrollDirection(action *ParsedAction) {
|
|||
// validateTypeContent 验证输入文本内容
|
||||
func validateTypeContent(action *ParsedAction) {
|
||||
if content, ok := action.ActionInputs["content"]; !ok || content == "" {
|
||||
// 为空或缺失的内容设置默认值
|
||||
// default to empty string
|
||||
action.ActionInputs["content"] = ""
|
||||
log.Printf("警告: type动作缺少content参数, 已设置为空字符串")
|
||||
log.Warn().Msg("type action missing content parameter, set to default")
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -2,7 +2,7 @@ package planner
|
|||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/cloudwego/eino/schema"
|
||||
|
@ -14,7 +14,6 @@ func TestVLMPlanning(t *testing.T) {
|
|||
err := loadEnv("testdata/.env")
|
||||
require.NoError(t, err)
|
||||
|
||||
// imageBase64, err := loadImage("testdata/popup_risk_warning.png")
|
||||
imageBase64, err := loadImage("testdata/llk_1.png")
|
||||
require.NoError(t, err)
|
||||
|
||||
|
@ -29,12 +28,12 @@ func TestVLMPlanning(t *testing.T) {
|
|||
5. 得分机制: 每成功连接并消除一对图案,玩家会获得相应的分数。完成游戏后,根据剩余时间和消除效率计算总分。
|
||||
6. 关卡设计: 游戏可能包含多个关卡,随着关卡的推进,图案的复杂度和数量会增加。`
|
||||
|
||||
userInstruction += "\n\n请基于以上游戏规则,给出下一步可点击的两个图标坐标"
|
||||
userInstruction += "\n\n请基于以上游戏规则,请先点击第一个图标"
|
||||
|
||||
planner, err := NewPlanner(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
opts := PlanningOptions{
|
||||
opts := &PlanningOptions{
|
||||
UserInstruction: userInstruction,
|
||||
ConversationHistory: []*schema.Message{
|
||||
{
|
||||
|
@ -49,10 +48,6 @@ func TestVLMPlanning(t *testing.T) {
|
|||
},
|
||||
},
|
||||
},
|
||||
Size: Size{
|
||||
Width: 1920,
|
||||
Height: 1080,
|
||||
},
|
||||
}
|
||||
|
||||
// 执行规划
|
||||
|
@ -61,10 +56,10 @@ func TestVLMPlanning(t *testing.T) {
|
|||
// 验证结果
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.NotEmpty(t, result.RealActions)
|
||||
require.NotEmpty(t, result.Actions)
|
||||
|
||||
// 验证动作
|
||||
action := result.RealActions[0]
|
||||
action := result.Actions[0]
|
||||
assert.NotEmpty(t, action.ActionType)
|
||||
assert.NotEmpty(t, action.Thought)
|
||||
|
||||
|
@ -75,15 +70,13 @@ func TestVLMPlanning(t *testing.T) {
|
|||
assert.NotEmpty(t, action.ActionInputs["startBox"])
|
||||
|
||||
// 验证坐标格式
|
||||
var coords []float64
|
||||
err = json.Unmarshal([]byte(action.ActionInputs["startBox"].(string)), &coords)
|
||||
require.NoError(t, err)
|
||||
coords, ok := action.ActionInputs["startBox"].([]float64)
|
||||
require.True(t, ok)
|
||||
require.True(t, len(coords) >= 2) // 至少有 x, y 坐标
|
||||
|
||||
// 验证坐标范围
|
||||
for _, coord := range coords {
|
||||
assert.GreaterOrEqual(t, coord, float64(0))
|
||||
assert.LessOrEqual(t, coord, float64(1920)) // 最大屏幕宽度
|
||||
}
|
||||
|
||||
case "type":
|
||||
|
@ -102,18 +95,61 @@ func TestVLMPlanning(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestXHSPlanning(t *testing.T) {
|
||||
err := loadEnv("testdata/.env")
|
||||
require.NoError(t, err)
|
||||
|
||||
imageBase64, err := loadImage("testdata/xhs-feed.jpeg")
|
||||
require.NoError(t, err)
|
||||
|
||||
userInstruction := `点击第二个帖子的作者头像`
|
||||
|
||||
planner, err := NewPlanner(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
opts := &PlanningOptions{
|
||||
UserInstruction: userInstruction,
|
||||
ConversationHistory: []*schema.Message{
|
||||
{
|
||||
Role: schema.User,
|
||||
MultiContent: []schema.ChatMessagePart{
|
||||
{
|
||||
Type: "image_url",
|
||||
ImageURL: &schema.ChatMessageImageURL{
|
||||
URL: imageBase64,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// 执行规划
|
||||
result, err := planner.Start(opts)
|
||||
|
||||
// 验证结果
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.NotEmpty(t, result.Actions)
|
||||
|
||||
// 验证动作
|
||||
action := result.Actions[0]
|
||||
assert.NotEmpty(t, action.ActionType)
|
||||
assert.NotEmpty(t, action.Thought)
|
||||
}
|
||||
|
||||
func TestValidateInput(t *testing.T) {
|
||||
imageBase64, err := loadImage("testdata/popup_risk_warning.png")
|
||||
require.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
opts PlanningOptions
|
||||
opts *PlanningOptions
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "valid input",
|
||||
opts: PlanningOptions{
|
||||
opts: &PlanningOptions{
|
||||
UserInstruction: "点击继续使用按钮",
|
||||
ConversationHistory: []*schema.Message{
|
||||
{
|
||||
|
@ -128,13 +164,12 @@ func TestValidateInput(t *testing.T) {
|
|||
},
|
||||
},
|
||||
},
|
||||
Size: Size{Width: 100, Height: 100},
|
||||
},
|
||||
wantErr: nil,
|
||||
},
|
||||
{
|
||||
name: "empty instruction",
|
||||
opts: PlanningOptions{
|
||||
opts: &PlanningOptions{
|
||||
UserInstruction: "",
|
||||
ConversationHistory: []*schema.Message{
|
||||
{
|
||||
|
@ -142,32 +177,29 @@ func TestValidateInput(t *testing.T) {
|
|||
Content: "",
|
||||
},
|
||||
},
|
||||
Size: Size{Width: 100, Height: 100},
|
||||
},
|
||||
wantErr: ErrEmptyInstruction,
|
||||
},
|
||||
{
|
||||
name: "empty conversation history",
|
||||
opts: PlanningOptions{
|
||||
opts: &PlanningOptions{
|
||||
UserInstruction: "点击立即卸载按钮",
|
||||
ConversationHistory: []*schema.Message{},
|
||||
Size: Size{Width: 100, Height: 100},
|
||||
},
|
||||
wantErr: ErrNoConversationHistory,
|
||||
},
|
||||
{
|
||||
name: "invalid size",
|
||||
opts: PlanningOptions{
|
||||
UserInstruction: "勾选不再提示选项",
|
||||
name: "invalid image data",
|
||||
opts: &PlanningOptions{
|
||||
UserInstruction: "点击继续使用按钮",
|
||||
ConversationHistory: []*schema.Message{
|
||||
{
|
||||
Role: schema.User,
|
||||
Content: "",
|
||||
Content: "no image",
|
||||
},
|
||||
},
|
||||
Size: Size{Width: 0, Height: 0},
|
||||
},
|
||||
wantErr: ErrInvalidInput,
|
||||
wantErr: ErrInvalidImageData,
|
||||
},
|
||||
}
|
||||
|
||||
|
@ -176,6 +208,7 @@ func TestValidateInput(t *testing.T) {
|
|||
err := validateInput(tt.opts)
|
||||
if tt.wantErr != nil {
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, tt.wantErr, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
@ -186,40 +219,32 @@ func TestValidateInput(t *testing.T) {
|
|||
func TestProcessVLMResponse(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
resp *VLMResponse
|
||||
actions []ParsedAction
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid response",
|
||||
resp: &VLMResponse{
|
||||
Actions: []ParsedAction{
|
||||
{
|
||||
ActionType: "click",
|
||||
ActionInputs: map[string]interface{}{
|
||||
"startBox": "[0.5, 0.5]",
|
||||
},
|
||||
actions: []ParsedAction{
|
||||
{
|
||||
ActionType: "click",
|
||||
ActionInputs: map[string]interface{}{
|
||||
"startBox": []float64{0.5, 0.5},
|
||||
},
|
||||
Thought: "点击中心位置",
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "error response",
|
||||
resp: &VLMResponse{
|
||||
Error: "test error",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "empty actions",
|
||||
resp: &VLMResponse{},
|
||||
actions: []ParsedAction{},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := processVLMResponse(tt.resp)
|
||||
result, err := processVLMResponse(tt.actions)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
|
@ -228,7 +253,7 @@ func TestProcessVLMResponse(t *testing.T) {
|
|||
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
assert.Equal(t, tt.resp.Actions, result.RealActions)
|
||||
assert.Equal(t, tt.actions, result.Actions)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -237,7 +262,6 @@ func TestSavePositionImg(t *testing.T) {
|
|||
imageBase64, err := loadImage("testdata/popup_risk_warning.png")
|
||||
require.NoError(t, err)
|
||||
|
||||
tempFile := t.TempDir() + "/test.png"
|
||||
params := struct {
|
||||
InputImgBase64 string
|
||||
Rect struct {
|
||||
|
@ -254,10 +278,12 @@ func TestSavePositionImg(t *testing.T) {
|
|||
X: 100,
|
||||
Y: 100,
|
||||
},
|
||||
OutputPath: tempFile,
|
||||
OutputPath: "testdata/output.png",
|
||||
}
|
||||
|
||||
err = SavePositionImg(params)
|
||||
assert.NoError(t, err)
|
||||
// TODO: Add more assertions when SavePositionImg is implemented
|
||||
require.NoError(t, err)
|
||||
|
||||
// cleanup
|
||||
defer os.Remove(params.OutputPath)
|
||||
}
|
||||
|
|
|
@ -0,0 +1,27 @@
|
|||
package planner
|
||||
|
||||
const uiTarsPlanningPrompt = `
|
||||
You are a GUI agent. You are given a task and your action history, with screenshots. You need to perform the next action to complete the task.
|
||||
|
||||
## Output Format
|
||||
Thought: ...
|
||||
Action: ...
|
||||
|
||||
## Action Space
|
||||
click(start_box='[x1, y1, x2, y2]')
|
||||
left_double(start_box='[x1, y1, x2, y2]')
|
||||
right_single(start_box='[x1, y1, x2, y2]')
|
||||
drag(start_box='[x1, y1, x2, y2]', end_box='[x3, y3, x4, y4]')
|
||||
hotkey(key='')
|
||||
type(content='') #If you want to submit your input, use "\n" at the end of content.
|
||||
scroll(start_box='[x1, y1, x2, y2]', direction='down or up or right or left')
|
||||
wait() #Sleep for 5s and take a screenshot to check for any changes.
|
||||
finished()
|
||||
call_user() # Submit the task and call the user when the task is unsolvable, or when you need the user's help.
|
||||
|
||||
## Note
|
||||
- Use Chinese in Thought part.
|
||||
- Write a small plan and finally summarize your next action (with its target element) in one sentence in Thought part.
|
||||
|
||||
## User Instruction
|
||||
`
|
Binary file not shown.
After Width: | Height: | Size: 649 KiB |
|
@ -8,20 +8,12 @@ import (
|
|||
type PlanningOptions struct {
|
||||
UserInstruction string `json:"user_instruction"`
|
||||
ConversationHistory []*schema.Message `json:"conversation_history"`
|
||||
Size Size `json:"size"`
|
||||
}
|
||||
|
||||
// Size represents the dimensions of a screen
|
||||
type Size struct {
|
||||
Width int `json:"width"`
|
||||
Height int `json:"height"`
|
||||
}
|
||||
|
||||
// PlanningResult represents the result of planning
|
||||
type PlanningResult struct {
|
||||
Actions []interface{} `json:"actions"`
|
||||
RealActions []ParsedAction `json:"real_actions"`
|
||||
ActionSummary string `json:"action_summary"`
|
||||
Actions []ParsedAction `json:"actions"`
|
||||
ActionSummary string `json:"summary"`
|
||||
}
|
||||
|
||||
// VLMResponse represents the response from the Vision Language Model
|
||||
|
|
Loading…
Reference in New Issue