refactor: ui-tars planner
This commit is contained in:
parent
6c74727c44
commit
7da305f577
|
@ -1 +1 @@
|
||||||
v5.0.0-beta-2503190009
|
v5.0.0-beta-2503191748
|
||||||
|
|
|
@ -3,6 +3,7 @@ package planner
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"strconv"
|
"strconv"
|
||||||
"time"
|
"time"
|
||||||
|
@ -89,6 +90,20 @@ func GetEnvConfigInInt(key string, defaultValue int) int {
|
||||||
return intValue
|
return intValue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CustomTransport is a custom RoundTripper that adds headers to every request
|
||||||
|
type CustomTransport struct {
|
||||||
|
Transport http.RoundTripper
|
||||||
|
Headers map[string]string
|
||||||
|
}
|
||||||
|
|
||||||
|
// RoundTrip executes a single HTTP transaction and adds custom headers
|
||||||
|
func (c *CustomTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
|
for key, value := range c.Headers {
|
||||||
|
req.Header.Set(key, value)
|
||||||
|
}
|
||||||
|
return c.Transport.RoundTrip(req)
|
||||||
|
}
|
||||||
|
|
||||||
// GetModelConfig get OpenAI config
|
// GetModelConfig get OpenAI config
|
||||||
func GetModelConfig() (*openai.ChatModelConfig, error) {
|
func GetModelConfig() (*openai.ChatModelConfig, error) {
|
||||||
envConfig := &OpenAIInitConfig{
|
envConfig := &OpenAIInitConfig{
|
||||||
|
@ -104,7 +119,13 @@ func GetModelConfig() (*openai.ChatModelConfig, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
config := &openai.ChatModelConfig{
|
config := &openai.ChatModelConfig{
|
||||||
Timeout: defaultTimeout,
|
HTTPClient: &http.Client{
|
||||||
|
Timeout: defaultTimeout,
|
||||||
|
Transport: &CustomTransport{
|
||||||
|
Transport: http.DefaultTransport,
|
||||||
|
Headers: envConfig.Headers,
|
||||||
|
},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
if baseURL := GetEnvConfig(EnvOpenAIBaseURL); baseURL != "" {
|
if baseURL := GetEnvConfig(EnvOpenAIBaseURL); baseURL != "" {
|
||||||
|
|
|
@ -10,40 +10,33 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
// NewActionParser creates a new ActionParser instance
|
// NewActionParser creates a new ActionParser instance
|
||||||
func NewActionParser(prediction string, factor float64) *ActionParser {
|
func NewActionParser(factor float64) *ActionParser {
|
||||||
return &ActionParser{
|
return &ActionParser{
|
||||||
Prediction: prediction,
|
Factor: factor,
|
||||||
Factor: factor,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ActionParser parses VLM responses and converts them to structured actions
|
// ActionParser parses VLM responses and converts them to structured actions
|
||||||
type ActionParser struct {
|
type ActionParser struct {
|
||||||
Prediction string
|
Factor float64 // TODO
|
||||||
Factor float64
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Parse parses the prediction text and extracts actions
|
// Parse parses the prediction text and extracts actions
|
||||||
func (p *ActionParser) Parse(predictionText string) ([]ParsedAction, error) {
|
func (p *ActionParser) Parse(predictionText string) ([]ParsedAction, error) {
|
||||||
// try parsing JSON format
|
// try parsing JSON format, from VLM like GPT-4o
|
||||||
var jsonActions []ParsedAction
|
var jsonActions []ParsedAction
|
||||||
jsonActions, jsonErr := p.parseJSON(predictionText)
|
jsonActions, jsonErr := p.parseJSON(predictionText)
|
||||||
if jsonErr == nil && len(jsonActions) > 0 {
|
if jsonErr == nil {
|
||||||
return jsonActions, nil
|
return jsonActions, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// if JSON parsing fails, try parsing Thought/Action format
|
// json parsing failed, try parsing Thought/Action format, from VLM like UI-TARS
|
||||||
thoughtActions, thoughtErr := p.parseThoughtAction(predictionText)
|
thoughtActions, thoughtErr := p.parseThoughtAction(predictionText)
|
||||||
if thoughtErr == nil && len(thoughtActions) > 0 {
|
if thoughtErr == nil {
|
||||||
return thoughtActions, nil
|
return thoughtActions, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// both parsing methods failed
|
return nil, fmt.Errorf("no valid actions returned from VLM, jsonErr: %v, thoughtErr: %v", jsonErr, thoughtErr)
|
||||||
if jsonErr != nil && thoughtErr != nil {
|
|
||||||
return nil, fmt.Errorf("failed to parse VLM response: %v; %v", jsonErr, thoughtErr)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, fmt.Errorf("no actions returned from VLM")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// parseJSON tries to parse the response as JSON format
|
// parseJSON tries to parse the response as JSON format
|
||||||
|
@ -92,7 +85,7 @@ func (p *ActionParser) parseThoughtAction(predictionText string) ([]ParsedAction
|
||||||
thought = strings.TrimSpace(thoughtMatch[1])
|
thought = strings.TrimSpace(thoughtMatch[1])
|
||||||
}
|
}
|
||||||
|
|
||||||
// extract Action part
|
// extract Action part, e.g. "click(start_box='(552,454)')"
|
||||||
actionMatch := actionRegex.FindStringSubmatch(predictionText)
|
actionMatch := actionRegex.FindStringSubmatch(predictionText)
|
||||||
if len(actionMatch) < 2 {
|
if len(actionMatch) < 2 {
|
||||||
return nil, fmt.Errorf("no action found in the response")
|
return nil, fmt.Errorf("no action found in the response")
|
||||||
|
@ -125,6 +118,7 @@ func (p *ActionParser) parseActionText(actionText, thought string) ([]ParsedActi
|
||||||
"call_user": regexp.MustCompile(`call_user\(\)`),
|
"call_user": regexp.MustCompile(`call_user\(\)`),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
parsedActions := make([]ParsedAction, 0)
|
||||||
for actionType, regex := range actionRegexes {
|
for actionType, regex := range actionRegexes {
|
||||||
matches := regex.FindStringSubmatch(actionText)
|
matches := regex.FindStringSubmatch(actionText)
|
||||||
if len(matches) == 0 {
|
if len(matches) == 0 {
|
||||||
|
@ -183,10 +177,13 @@ func (p *ActionParser) parseActionText(actionText, thought string) ([]ParsedActi
|
||||||
// 这些动作没有额外参数
|
// 这些动作没有额外参数
|
||||||
}
|
}
|
||||||
|
|
||||||
return []ParsedAction{action}, nil
|
parsedActions = append(parsedActions, action)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, fmt.Errorf("unknown action format: %s", actionText)
|
if len(parsedActions) == 0 {
|
||||||
|
return nil, fmt.Errorf("no valid actions returned from VLM")
|
||||||
|
}
|
||||||
|
return parsedActions, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// normalizeAction normalizes the coordinates in the action
|
// normalizeAction normalizes the coordinates in the action
|
||||||
|
@ -215,16 +212,14 @@ func (p *ActionParser) normalizeAction(action *ParsedAction) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// normalizeCoordinates normalizes the coordinates based on the factor
|
// normalizeCoordinates normalizes the coordinates based on the factor
|
||||||
func (p *ActionParser) normalizeCoordinates(coordStr string) (string, error) {
|
func (p *ActionParser) normalizeCoordinates(coordStr string) (coords []float64, err error) {
|
||||||
var coords []float64
|
|
||||||
|
|
||||||
// check empty string
|
// check empty string
|
||||||
if coordStr == "" {
|
if coordStr == "" {
|
||||||
return "", fmt.Errorf("empty coordinate string")
|
return nil, fmt.Errorf("empty coordinate string")
|
||||||
}
|
}
|
||||||
|
|
||||||
if !strings.Contains(coordStr, ",") {
|
if !strings.Contains(coordStr, ",") {
|
||||||
return "", fmt.Errorf("invalid coordinate string: %s", coordStr)
|
return nil, fmt.Errorf("invalid coordinate string: %s", coordStr)
|
||||||
}
|
}
|
||||||
|
|
||||||
// remove possible brackets and split coordinates
|
// remove possible brackets and split coordinates
|
||||||
|
@ -236,15 +231,9 @@ func (p *ActionParser) normalizeCoordinates(coordStr string) (string, error) {
|
||||||
jsonStr = "[" + coordStr + "]"
|
jsonStr = "[" + coordStr + "]"
|
||||||
}
|
}
|
||||||
|
|
||||||
err := json.Unmarshal([]byte(jsonStr), &coords)
|
err = json.Unmarshal([]byte(jsonStr), &coords)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("failed to parse coordinate string: %w", err)
|
return nil, fmt.Errorf("failed to parse coordinate string: %w", err)
|
||||||
}
|
}
|
||||||
|
return coords, nil
|
||||||
normalized, err := json.Marshal(coords)
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("failed to marshal normalized coordinates: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return string(normalized), nil
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -20,38 +20,11 @@ import (
|
||||||
|
|
||||||
// Error types
|
// Error types
|
||||||
var (
|
var (
|
||||||
ErrInvalidInput = fmt.Errorf("invalid input parameters")
|
|
||||||
ErrEmptyInstruction = fmt.Errorf("user instruction is empty")
|
ErrEmptyInstruction = fmt.Errorf("user instruction is empty")
|
||||||
ErrNoConversationHistory = fmt.Errorf("conversation history is empty")
|
ErrNoConversationHistory = fmt.Errorf("conversation history is empty")
|
||||||
ErrInvalidImageData = fmt.Errorf("invalid image data")
|
ErrInvalidImageData = fmt.Errorf("invalid image data")
|
||||||
)
|
)
|
||||||
|
|
||||||
const uiTarsPlanningPrompt = `
|
|
||||||
You are a GUI agent. You are given a task and your action history, with screenshots. You need to perform the next action to complete the task.
|
|
||||||
|
|
||||||
## Output Format
|
|
||||||
Thought: ...
|
|
||||||
Action: ...
|
|
||||||
|
|
||||||
## Action Space
|
|
||||||
click(start_box='[x1, y1, x2, y2]')
|
|
||||||
left_double(start_box='[x1, y1, x2, y2]')
|
|
||||||
right_single(start_box='[x1, y1, x2, y2]')
|
|
||||||
drag(start_box='[x1, y1, x2, y2]', end_box='[x3, y3, x4, y4]')
|
|
||||||
hotkey(key='')
|
|
||||||
type(content='') #If you want to submit your input, use "\n" at the end of content.
|
|
||||||
scroll(start_box='[x1, y1, x2, y2]', direction='down or up or right or left')
|
|
||||||
wait() #Sleep for 5s and take a screenshot to check for any changes.
|
|
||||||
finished()
|
|
||||||
call_user() # Submit the task and call the user when the task is unsolvable, or when you need the user's help.
|
|
||||||
|
|
||||||
## Note
|
|
||||||
- Use Chinese in Thought part.
|
|
||||||
- Write a small plan and finally summarize your next action (with its target element) in one sentence in Thought part.
|
|
||||||
|
|
||||||
## User Instruction
|
|
||||||
`
|
|
||||||
|
|
||||||
func NewPlanner(ctx context.Context) (*Planner, error) {
|
func NewPlanner(ctx context.Context) (*Planner, error) {
|
||||||
config, err := GetModelConfig()
|
config, err := GetModelConfig()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -61,36 +34,39 @@ func NewPlanner(ctx context.Context) (*Planner, error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to initialize OpenAI model: %w", err)
|
return nil, fmt.Errorf("failed to initialize OpenAI model: %w", err)
|
||||||
}
|
}
|
||||||
|
parser := NewActionParser(1000)
|
||||||
return &Planner{
|
return &Planner{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
model: model,
|
model: model,
|
||||||
|
parser: parser,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type Planner struct {
|
type Planner struct {
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
model *openai.ChatModel
|
model *openai.ChatModel
|
||||||
|
parser *ActionParser
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start performs UI planning using Vision Language Model
|
// Start performs UI planning using Vision Language Model
|
||||||
func (p *Planner) Start(opts PlanningOptions) (*PlanningResult, error) {
|
func (p *Planner) Start(opts *PlanningOptions) (*PlanningResult, error) {
|
||||||
log.Info().Str("user_instruction", opts.UserInstruction).Msg("start VLM planning")
|
log.Info().Str("user_instruction", opts.UserInstruction).Msg("start VLM planning")
|
||||||
|
|
||||||
// 1. validate input parameters
|
// validate input parameters
|
||||||
if err := validateInput(opts); err != nil {
|
if err := validateInput(opts); err != nil {
|
||||||
return nil, errors.Wrap(err, "validate input parameters failed")
|
return nil, errors.Wrap(err, "validate input parameters failed")
|
||||||
}
|
}
|
||||||
|
|
||||||
// 2. call VLM service
|
// call VLM service
|
||||||
resp, err := p.callVLMService(opts)
|
resp, err := p.callVLMService(opts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "call VLM service failed")
|
return nil, errors.Wrap(err, "call VLM service failed")
|
||||||
}
|
}
|
||||||
|
|
||||||
// 3. process response
|
// parse result
|
||||||
result, err := processVLMResponse(resp)
|
result, err := p.parseResult(resp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "process VLM response failed")
|
return nil, errors.Wrap(err, "parse result failed")
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Info().
|
log.Info().
|
||||||
|
@ -100,7 +76,7 @@ func (p *Planner) Start(opts PlanningOptions) (*PlanningResult, error) {
|
||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func validateInput(opts PlanningOptions) error {
|
func validateInput(opts *PlanningOptions) error {
|
||||||
if opts.UserInstruction == "" {
|
if opts.UserInstruction == "" {
|
||||||
return ErrEmptyInstruction
|
return ErrEmptyInstruction
|
||||||
}
|
}
|
||||||
|
@ -109,10 +85,6 @@ func validateInput(opts PlanningOptions) error {
|
||||||
return ErrNoConversationHistory
|
return ErrNoConversationHistory
|
||||||
}
|
}
|
||||||
|
|
||||||
if opts.Size.Width <= 0 || opts.Size.Height <= 0 {
|
|
||||||
return ErrInvalidInput
|
|
||||||
}
|
|
||||||
|
|
||||||
// ensure at least one image URL
|
// ensure at least one image URL
|
||||||
hasImageURL := false
|
hasImageURL := false
|
||||||
for _, msg := range opts.ConversationHistory {
|
for _, msg := range opts.ConversationHistory {
|
||||||
|
@ -133,14 +105,14 @@ func validateInput(opts PlanningOptions) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
if !hasImageURL {
|
if !hasImageURL {
|
||||||
return ErrInvalidInput
|
return ErrInvalidImageData
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// callVLMService makes the actual call to the VLM service
|
// callVLMService makes the actual call to the VLM service
|
||||||
func (p *Planner) callVLMService(opts PlanningOptions) (*VLMResponse, error) {
|
func (p *Planner) callVLMService(opts *PlanningOptions) (*schema.Message, error) {
|
||||||
log.Info().Msg("calling VLM service...")
|
log.Info().Msg("calling VLM service...")
|
||||||
|
|
||||||
// prepare prompt
|
// prepare prompt
|
||||||
|
@ -158,87 +130,77 @@ func (p *Planner) callVLMService(opts PlanningOptions) (*VLMResponse, error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("OpenAI API request failed: %w", err)
|
return nil, fmt.Errorf("OpenAI API request failed: %w", err)
|
||||||
}
|
}
|
||||||
|
return resp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Planner) parseResult(msg *schema.Message) (*PlanningResult, error) {
|
||||||
// parse response
|
// parse response
|
||||||
content := resp.Content
|
actions, err := p.parser.Parse(msg.Content)
|
||||||
parser := NewActionParser(content, 1000) // 使用与 TypeScript 版本相同的 factor
|
|
||||||
actions, err := parser.Parse(content)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to parse actions: %w", err)
|
return nil, fmt.Errorf("failed to parse actions: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return &VLMResponse{
|
// process response
|
||||||
Actions: actions,
|
result, err := processVLMResponse(actions)
|
||||||
}, nil
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "process VLM response failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// processVLMResponse processes the VLM response and converts it to PlanningResult
|
// processVLMResponse processes the VLM response and converts it to PlanningResult
|
||||||
func processVLMResponse(resp *VLMResponse) (*PlanningResult, error) {
|
func processVLMResponse(actions []ParsedAction) (*PlanningResult, error) {
|
||||||
log.Info().Msg("processing VLM response...")
|
log.Info().Msg("processing VLM response...")
|
||||||
if resp.Error != "" {
|
|
||||||
return nil, fmt.Errorf("VLM error: %s", resp.Error)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(resp.Actions) == 0 {
|
if len(actions) == 0 {
|
||||||
return nil, fmt.Errorf("no actions returned from VLM")
|
return nil, fmt.Errorf("no actions returned from VLM")
|
||||||
}
|
}
|
||||||
|
|
||||||
// 验证和后处理每个动作
|
// validate and post-process each action
|
||||||
for i := range resp.Actions {
|
for i := range actions {
|
||||||
// 验证动作类型
|
// validate action type
|
||||||
switch resp.Actions[i].ActionType {
|
switch actions[i].ActionType {
|
||||||
case "click", "left_double", "right_single":
|
case "click", "left_double", "right_single":
|
||||||
validateCoordinateAction(&resp.Actions[i], "startBox")
|
validateCoordinateAction(&actions[i], "startBox")
|
||||||
case "drag":
|
case "drag":
|
||||||
validateCoordinateAction(&resp.Actions[i], "startBox")
|
validateCoordinateAction(&actions[i], "startBox")
|
||||||
validateCoordinateAction(&resp.Actions[i], "endBox")
|
validateCoordinateAction(&actions[i], "endBox")
|
||||||
case "scroll":
|
case "scroll":
|
||||||
validateCoordinateAction(&resp.Actions[i], "startBox")
|
validateCoordinateAction(&actions[i], "startBox")
|
||||||
validateScrollDirection(&resp.Actions[i])
|
validateScrollDirection(&actions[i])
|
||||||
case "type":
|
case "type":
|
||||||
validateTypeContent(&resp.Actions[i])
|
validateTypeContent(&actions[i])
|
||||||
case "hotkey":
|
case "hotkey":
|
||||||
validateHotkeyAction(&resp.Actions[i])
|
validateHotkeyAction(&actions[i])
|
||||||
case "wait", "finished", "call_user":
|
case "wait", "finished", "call_user":
|
||||||
// 这些动作不需要额外参数
|
// these actions do not need extra parameters
|
||||||
default:
|
default:
|
||||||
log.Printf("警告: 未知的动作类型: %s, 将尝试继续处理", resp.Actions[i].ActionType)
|
log.Printf("warning: unknown action type: %s, will try to continue processing", actions[i].ActionType)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 提取动作摘要
|
// extract action summary
|
||||||
actionSummary := extractActionSummary(resp.Actions)
|
actionSummary := extractActionSummary(actions)
|
||||||
|
|
||||||
// 将ParsedAction转换为接口类型
|
|
||||||
var actions []interface{}
|
|
||||||
for _, action := range resp.Actions {
|
|
||||||
actionMap := map[string]interface{}{
|
|
||||||
"actionType": action.ActionType,
|
|
||||||
"actionInputs": action.ActionInputs,
|
|
||||||
"thought": action.Thought,
|
|
||||||
}
|
|
||||||
actions = append(actions, actionMap)
|
|
||||||
}
|
|
||||||
|
|
||||||
return &PlanningResult{
|
return &PlanningResult{
|
||||||
Actions: actions,
|
Actions: actions,
|
||||||
RealActions: resp.Actions,
|
|
||||||
ActionSummary: actionSummary,
|
ActionSummary: actionSummary,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// extractActionSummary 从动作中提取摘要
|
// extractActionSummary extracts the summary from the actions
|
||||||
func extractActionSummary(actions []ParsedAction) string {
|
func extractActionSummary(actions []ParsedAction) string {
|
||||||
if len(actions) == 0 {
|
if len(actions) == 0 {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
// 优先使用第一个动作的Thought作为摘要
|
// use the Thought of the first action as summary
|
||||||
if actions[0].Thought != "" {
|
if actions[0].Thought != "" {
|
||||||
return actions[0].Thought
|
return actions[0].Thought
|
||||||
}
|
}
|
||||||
|
|
||||||
// 如果没有Thought,则根据动作类型生成摘要
|
// if no Thought, generate summary from action type
|
||||||
action := actions[0]
|
action := actions[0]
|
||||||
switch action.ActionType {
|
switch action.ActionType {
|
||||||
case "click":
|
case "click":
|
||||||
|
@ -274,28 +236,21 @@ func extractActionSummary(actions []ParsedAction) string {
|
||||||
|
|
||||||
// validateCoordinateAction 验证坐标类动作
|
// validateCoordinateAction 验证坐标类动作
|
||||||
func validateCoordinateAction(action *ParsedAction, boxField string) {
|
func validateCoordinateAction(action *ParsedAction, boxField string) {
|
||||||
if box, ok := action.ActionInputs[boxField]; !ok || box == "" {
|
// TODO
|
||||||
// 为空或缺失的坐标设置默认值
|
|
||||||
action.ActionInputs[boxField] = "[0.5, 0.5]"
|
|
||||||
log.Printf("警告: %s动作缺少%s参数, 已设置默认值", action.ActionType, boxField)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// validateScrollDirection 验证滚动方向
|
// validateScrollDirection 验证滚动方向
|
||||||
func validateScrollDirection(action *ParsedAction) {
|
func validateScrollDirection(action *ParsedAction) {
|
||||||
if direction, ok := action.ActionInputs["direction"].(string); !ok || direction == "" {
|
if direction, ok := action.ActionInputs["direction"].(string); !ok || direction == "" {
|
||||||
// 为空或缺失的方向设置默认值
|
// default to down
|
||||||
action.ActionInputs["direction"] = "down"
|
action.ActionInputs["direction"] = "down"
|
||||||
log.Printf("警告: scroll动作缺少direction参数, 已设置默认值")
|
|
||||||
} else {
|
} else {
|
||||||
// 标准化方向
|
|
||||||
switch strings.ToLower(direction) {
|
switch strings.ToLower(direction) {
|
||||||
case "up", "down", "left", "right":
|
case "up", "down", "left", "right":
|
||||||
// 保持原样
|
// keep original direction
|
||||||
default:
|
default:
|
||||||
// 非标准方向设为默认值
|
|
||||||
action.ActionInputs["direction"] = "down"
|
action.ActionInputs["direction"] = "down"
|
||||||
log.Printf("警告: 非标准滚动方向: %s, 已设置为down", direction)
|
log.Warn().Str("direction", direction).Msg("invalid scroll direction, set to default")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -303,9 +258,9 @@ func validateScrollDirection(action *ParsedAction) {
|
||||||
// validateTypeContent 验证输入文本内容
|
// validateTypeContent 验证输入文本内容
|
||||||
func validateTypeContent(action *ParsedAction) {
|
func validateTypeContent(action *ParsedAction) {
|
||||||
if content, ok := action.ActionInputs["content"]; !ok || content == "" {
|
if content, ok := action.ActionInputs["content"]; !ok || content == "" {
|
||||||
// 为空或缺失的内容设置默认值
|
// default to empty string
|
||||||
action.ActionInputs["content"] = ""
|
action.ActionInputs["content"] = ""
|
||||||
log.Printf("警告: type动作缺少content参数, 已设置为空字符串")
|
log.Warn().Msg("type action missing content parameter, set to default")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -2,7 +2,7 @@ package planner
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"os"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/cloudwego/eino/schema"
|
"github.com/cloudwego/eino/schema"
|
||||||
|
@ -14,7 +14,6 @@ func TestVLMPlanning(t *testing.T) {
|
||||||
err := loadEnv("testdata/.env")
|
err := loadEnv("testdata/.env")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// imageBase64, err := loadImage("testdata/popup_risk_warning.png")
|
|
||||||
imageBase64, err := loadImage("testdata/llk_1.png")
|
imageBase64, err := loadImage("testdata/llk_1.png")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
@ -29,12 +28,12 @@ func TestVLMPlanning(t *testing.T) {
|
||||||
5. 得分机制: 每成功连接并消除一对图案,玩家会获得相应的分数。完成游戏后,根据剩余时间和消除效率计算总分。
|
5. 得分机制: 每成功连接并消除一对图案,玩家会获得相应的分数。完成游戏后,根据剩余时间和消除效率计算总分。
|
||||||
6. 关卡设计: 游戏可能包含多个关卡,随着关卡的推进,图案的复杂度和数量会增加。`
|
6. 关卡设计: 游戏可能包含多个关卡,随着关卡的推进,图案的复杂度和数量会增加。`
|
||||||
|
|
||||||
userInstruction += "\n\n请基于以上游戏规则,给出下一步可点击的两个图标坐标"
|
userInstruction += "\n\n请基于以上游戏规则,请先点击第一个图标"
|
||||||
|
|
||||||
planner, err := NewPlanner(context.Background())
|
planner, err := NewPlanner(context.Background())
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
opts := PlanningOptions{
|
opts := &PlanningOptions{
|
||||||
UserInstruction: userInstruction,
|
UserInstruction: userInstruction,
|
||||||
ConversationHistory: []*schema.Message{
|
ConversationHistory: []*schema.Message{
|
||||||
{
|
{
|
||||||
|
@ -49,10 +48,6 @@ func TestVLMPlanning(t *testing.T) {
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
Size: Size{
|
|
||||||
Width: 1920,
|
|
||||||
Height: 1080,
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 执行规划
|
// 执行规划
|
||||||
|
@ -61,10 +56,10 @@ func TestVLMPlanning(t *testing.T) {
|
||||||
// 验证结果
|
// 验证结果
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NotNil(t, result)
|
require.NotNil(t, result)
|
||||||
require.NotEmpty(t, result.RealActions)
|
require.NotEmpty(t, result.Actions)
|
||||||
|
|
||||||
// 验证动作
|
// 验证动作
|
||||||
action := result.RealActions[0]
|
action := result.Actions[0]
|
||||||
assert.NotEmpty(t, action.ActionType)
|
assert.NotEmpty(t, action.ActionType)
|
||||||
assert.NotEmpty(t, action.Thought)
|
assert.NotEmpty(t, action.Thought)
|
||||||
|
|
||||||
|
@ -75,15 +70,13 @@ func TestVLMPlanning(t *testing.T) {
|
||||||
assert.NotEmpty(t, action.ActionInputs["startBox"])
|
assert.NotEmpty(t, action.ActionInputs["startBox"])
|
||||||
|
|
||||||
// 验证坐标格式
|
// 验证坐标格式
|
||||||
var coords []float64
|
coords, ok := action.ActionInputs["startBox"].([]float64)
|
||||||
err = json.Unmarshal([]byte(action.ActionInputs["startBox"].(string)), &coords)
|
require.True(t, ok)
|
||||||
require.NoError(t, err)
|
|
||||||
require.True(t, len(coords) >= 2) // 至少有 x, y 坐标
|
require.True(t, len(coords) >= 2) // 至少有 x, y 坐标
|
||||||
|
|
||||||
// 验证坐标范围
|
// 验证坐标范围
|
||||||
for _, coord := range coords {
|
for _, coord := range coords {
|
||||||
assert.GreaterOrEqual(t, coord, float64(0))
|
assert.GreaterOrEqual(t, coord, float64(0))
|
||||||
assert.LessOrEqual(t, coord, float64(1920)) // 最大屏幕宽度
|
|
||||||
}
|
}
|
||||||
|
|
||||||
case "type":
|
case "type":
|
||||||
|
@ -102,18 +95,61 @@ func TestVLMPlanning(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestXHSPlanning(t *testing.T) {
|
||||||
|
err := loadEnv("testdata/.env")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
imageBase64, err := loadImage("testdata/xhs-feed.jpeg")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
userInstruction := `点击第二个帖子的作者头像`
|
||||||
|
|
||||||
|
planner, err := NewPlanner(context.Background())
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
opts := &PlanningOptions{
|
||||||
|
UserInstruction: userInstruction,
|
||||||
|
ConversationHistory: []*schema.Message{
|
||||||
|
{
|
||||||
|
Role: schema.User,
|
||||||
|
MultiContent: []schema.ChatMessagePart{
|
||||||
|
{
|
||||||
|
Type: "image_url",
|
||||||
|
ImageURL: &schema.ChatMessageImageURL{
|
||||||
|
URL: imageBase64,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// 执行规划
|
||||||
|
result, err := planner.Start(opts)
|
||||||
|
|
||||||
|
// 验证结果
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.NotEmpty(t, result.Actions)
|
||||||
|
|
||||||
|
// 验证动作
|
||||||
|
action := result.Actions[0]
|
||||||
|
assert.NotEmpty(t, action.ActionType)
|
||||||
|
assert.NotEmpty(t, action.Thought)
|
||||||
|
}
|
||||||
|
|
||||||
func TestValidateInput(t *testing.T) {
|
func TestValidateInput(t *testing.T) {
|
||||||
imageBase64, err := loadImage("testdata/popup_risk_warning.png")
|
imageBase64, err := loadImage("testdata/popup_risk_warning.png")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
opts PlanningOptions
|
opts *PlanningOptions
|
||||||
wantErr error
|
wantErr error
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "valid input",
|
name: "valid input",
|
||||||
opts: PlanningOptions{
|
opts: &PlanningOptions{
|
||||||
UserInstruction: "点击继续使用按钮",
|
UserInstruction: "点击继续使用按钮",
|
||||||
ConversationHistory: []*schema.Message{
|
ConversationHistory: []*schema.Message{
|
||||||
{
|
{
|
||||||
|
@ -128,13 +164,12 @@ func TestValidateInput(t *testing.T) {
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
Size: Size{Width: 100, Height: 100},
|
|
||||||
},
|
},
|
||||||
wantErr: nil,
|
wantErr: nil,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "empty instruction",
|
name: "empty instruction",
|
||||||
opts: PlanningOptions{
|
opts: &PlanningOptions{
|
||||||
UserInstruction: "",
|
UserInstruction: "",
|
||||||
ConversationHistory: []*schema.Message{
|
ConversationHistory: []*schema.Message{
|
||||||
{
|
{
|
||||||
|
@ -142,32 +177,29 @@ func TestValidateInput(t *testing.T) {
|
||||||
Content: "",
|
Content: "",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
Size: Size{Width: 100, Height: 100},
|
|
||||||
},
|
},
|
||||||
wantErr: ErrEmptyInstruction,
|
wantErr: ErrEmptyInstruction,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "empty conversation history",
|
name: "empty conversation history",
|
||||||
opts: PlanningOptions{
|
opts: &PlanningOptions{
|
||||||
UserInstruction: "点击立即卸载按钮",
|
UserInstruction: "点击立即卸载按钮",
|
||||||
ConversationHistory: []*schema.Message{},
|
ConversationHistory: []*schema.Message{},
|
||||||
Size: Size{Width: 100, Height: 100},
|
|
||||||
},
|
},
|
||||||
wantErr: ErrNoConversationHistory,
|
wantErr: ErrNoConversationHistory,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "invalid size",
|
name: "invalid image data",
|
||||||
opts: PlanningOptions{
|
opts: &PlanningOptions{
|
||||||
UserInstruction: "勾选不再提示选项",
|
UserInstruction: "点击继续使用按钮",
|
||||||
ConversationHistory: []*schema.Message{
|
ConversationHistory: []*schema.Message{
|
||||||
{
|
{
|
||||||
Role: schema.User,
|
Role: schema.User,
|
||||||
Content: "",
|
Content: "no image",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
Size: Size{Width: 0, Height: 0},
|
|
||||||
},
|
},
|
||||||
wantErr: ErrInvalidInput,
|
wantErr: ErrInvalidImageData,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -176,6 +208,7 @@ func TestValidateInput(t *testing.T) {
|
||||||
err := validateInput(tt.opts)
|
err := validateInput(tt.opts)
|
||||||
if tt.wantErr != nil {
|
if tt.wantErr != nil {
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
|
assert.Equal(t, tt.wantErr, err)
|
||||||
} else {
|
} else {
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
}
|
}
|
||||||
|
@ -186,40 +219,32 @@ func TestValidateInput(t *testing.T) {
|
||||||
func TestProcessVLMResponse(t *testing.T) {
|
func TestProcessVLMResponse(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
resp *VLMResponse
|
actions []ParsedAction
|
||||||
wantErr bool
|
wantErr bool
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "valid response",
|
name: "valid response",
|
||||||
resp: &VLMResponse{
|
actions: []ParsedAction{
|
||||||
Actions: []ParsedAction{
|
{
|
||||||
{
|
ActionType: "click",
|
||||||
ActionType: "click",
|
ActionInputs: map[string]interface{}{
|
||||||
ActionInputs: map[string]interface{}{
|
"startBox": []float64{0.5, 0.5},
|
||||||
"startBox": "[0.5, 0.5]",
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
|
Thought: "点击中心位置",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
wantErr: false,
|
wantErr: false,
|
||||||
},
|
},
|
||||||
{
|
|
||||||
name: "error response",
|
|
||||||
resp: &VLMResponse{
|
|
||||||
Error: "test error",
|
|
||||||
},
|
|
||||||
wantErr: true,
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
name: "empty actions",
|
name: "empty actions",
|
||||||
resp: &VLMResponse{},
|
actions: []ParsedAction{},
|
||||||
wantErr: true,
|
wantErr: true,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
result, err := processVLMResponse(tt.resp)
|
result, err := processVLMResponse(tt.actions)
|
||||||
if tt.wantErr {
|
if tt.wantErr {
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
assert.Nil(t, result)
|
assert.Nil(t, result)
|
||||||
|
@ -228,7 +253,7 @@ func TestProcessVLMResponse(t *testing.T) {
|
||||||
|
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.NotNil(t, result)
|
assert.NotNil(t, result)
|
||||||
assert.Equal(t, tt.resp.Actions, result.RealActions)
|
assert.Equal(t, tt.actions, result.Actions)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -237,7 +262,6 @@ func TestSavePositionImg(t *testing.T) {
|
||||||
imageBase64, err := loadImage("testdata/popup_risk_warning.png")
|
imageBase64, err := loadImage("testdata/popup_risk_warning.png")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
tempFile := t.TempDir() + "/test.png"
|
|
||||||
params := struct {
|
params := struct {
|
||||||
InputImgBase64 string
|
InputImgBase64 string
|
||||||
Rect struct {
|
Rect struct {
|
||||||
|
@ -254,10 +278,12 @@ func TestSavePositionImg(t *testing.T) {
|
||||||
X: 100,
|
X: 100,
|
||||||
Y: 100,
|
Y: 100,
|
||||||
},
|
},
|
||||||
OutputPath: tempFile,
|
OutputPath: "testdata/output.png",
|
||||||
}
|
}
|
||||||
|
|
||||||
err = SavePositionImg(params)
|
err = SavePositionImg(params)
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
// TODO: Add more assertions when SavePositionImg is implemented
|
|
||||||
|
// cleanup
|
||||||
|
defer os.Remove(params.OutputPath)
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,27 @@
|
||||||
|
package planner
|
||||||
|
|
||||||
|
const uiTarsPlanningPrompt = `
|
||||||
|
You are a GUI agent. You are given a task and your action history, with screenshots. You need to perform the next action to complete the task.
|
||||||
|
|
||||||
|
## Output Format
|
||||||
|
Thought: ...
|
||||||
|
Action: ...
|
||||||
|
|
||||||
|
## Action Space
|
||||||
|
click(start_box='[x1, y1, x2, y2]')
|
||||||
|
left_double(start_box='[x1, y1, x2, y2]')
|
||||||
|
right_single(start_box='[x1, y1, x2, y2]')
|
||||||
|
drag(start_box='[x1, y1, x2, y2]', end_box='[x3, y3, x4, y4]')
|
||||||
|
hotkey(key='')
|
||||||
|
type(content='') #If you want to submit your input, use "\n" at the end of content.
|
||||||
|
scroll(start_box='[x1, y1, x2, y2]', direction='down or up or right or left')
|
||||||
|
wait() #Sleep for 5s and take a screenshot to check for any changes.
|
||||||
|
finished()
|
||||||
|
call_user() # Submit the task and call the user when the task is unsolvable, or when you need the user's help.
|
||||||
|
|
||||||
|
## Note
|
||||||
|
- Use Chinese in Thought part.
|
||||||
|
- Write a small plan and finally summarize your next action (with its target element) in one sentence in Thought part.
|
||||||
|
|
||||||
|
## User Instruction
|
||||||
|
`
|
Binary file not shown.
After Width: | Height: | Size: 649 KiB |
|
@ -8,20 +8,12 @@ import (
|
||||||
type PlanningOptions struct {
|
type PlanningOptions struct {
|
||||||
UserInstruction string `json:"user_instruction"`
|
UserInstruction string `json:"user_instruction"`
|
||||||
ConversationHistory []*schema.Message `json:"conversation_history"`
|
ConversationHistory []*schema.Message `json:"conversation_history"`
|
||||||
Size Size `json:"size"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// Size represents the dimensions of a screen
|
|
||||||
type Size struct {
|
|
||||||
Width int `json:"width"`
|
|
||||||
Height int `json:"height"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// PlanningResult represents the result of planning
|
// PlanningResult represents the result of planning
|
||||||
type PlanningResult struct {
|
type PlanningResult struct {
|
||||||
Actions []interface{} `json:"actions"`
|
Actions []ParsedAction `json:"actions"`
|
||||||
RealActions []ParsedAction `json:"real_actions"`
|
ActionSummary string `json:"summary"`
|
||||||
ActionSummary string `json:"action_summary"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// VLMResponse represents the response from the Vision Language Model
|
// VLMResponse represents the response from the Vision Language Model
|
||||||
|
|
Loading…
Reference in New Issue