fix: convertCoordinateAction

This commit is contained in:
lilong.129 2025-03-20 18:02:35 +08:00
parent 3801ffb744
commit da0bdc4fe5
4 changed files with 154 additions and 26 deletions

View File

@ -1 +1 @@
v5.0.0-beta-2503201423
v5.0.0-beta-2503201802

View File

@ -2,6 +2,7 @@ package ai
import (
"github.com/cloudwego/eino/schema"
"github.com/httprunner/httprunner/v5/uixt/types"
)
type ILLMService interface {
@ -22,6 +23,7 @@ func (s openaiLLMService) Call(opts *PlanningOptions) (*PlanningResult, error) {
// PlanningOptions bundles the inputs of a single planning call.
type PlanningOptions struct {
// UserInstruction is the instruction text; serialized as "user_instruction".
UserInstruction string `json:"user_instruction"`
// ConversationHistory holds the messages passed along with the instruction;
// serialized as "conversation_history".
ConversationHistory []*schema.Message `json:"conversation_history"`
// Size is forwarded by Planner.Call to parseResult, where it is used to
// convert the coordinates of parsed actions. The tests populate it with the
// dimensions returned by loadImage for the screenshot in the conversation.
Size types.Size `json:"size"`
}
// PlanningResult represents the result of planning

View File

@ -9,12 +9,14 @@ import (
"image/color"
"image/draw"
"image/png"
"math"
"os"
"strings"
"github.com/cloudwego/eino-ext/components/model/openai"
"github.com/cloudwego/eino/components/model"
"github.com/cloudwego/eino/schema"
"github.com/httprunner/httprunner/v5/uixt/types"
"github.com/pkg/errors"
"github.com/rs/zerolog/log"
)
@ -65,7 +67,7 @@ func (p *Planner) Call(opts *PlanningOptions) (*PlanningResult, error) {
}
// parse result
result, err := p.parseResult(resp)
result, err := p.parseResult(resp, opts.Size)
if err != nil {
return nil, errors.Wrap(err, "parse result failed")
}
@ -135,15 +137,15 @@ func (p *Planner) callVLMService(opts *PlanningOptions) (*schema.Message, error)
return resp, nil
}
func (p *Planner) parseResult(msg *schema.Message) (*PlanningResult, error) {
func (p *Planner) parseResult(msg *schema.Message, size types.Size) (*PlanningResult, error) {
// parse response
actions, err := p.parser.Parse(msg.Content)
parseActions, err := p.parser.Parse(msg.Content)
if err != nil {
return nil, fmt.Errorf("failed to parse actions: %w", err)
}
// process response
result, err := processVLMResponse(actions)
result, err := processVLMResponse(parseActions, size)
if err != nil {
return nil, errors.Wrap(err, "process VLM response failed")
}
@ -152,7 +154,7 @@ func (p *Planner) parseResult(msg *schema.Message) (*PlanningResult, error) {
}
// processVLMResponse processes the VLM response and converts it to PlanningResult
func processVLMResponse(actions []ParsedAction) (*PlanningResult, error) {
func processVLMResponse(actions []ParsedAction, size types.Size) (*PlanningResult, error) {
log.Info().Msg("processing VLM response...")
if len(actions) == 0 {
@ -163,11 +165,17 @@ func processVLMResponse(actions []ParsedAction) (*PlanningResult, error) {
for i := range actions {
// validate action type
switch actions[i].ActionType {
case "click", "left_double", "right_single":
validateCoordinateAction(&actions[i], "startBox")
case "click":
if err := convertCoordinateAction(&actions[i], "startBox", size); err != nil {
return nil, errors.Wrap(err, "convert coordinate action failed")
}
case "drag":
validateCoordinateAction(&actions[i], "startBox")
validateCoordinateAction(&actions[i], "endBox")
if err := convertCoordinateAction(&actions[i], "startBox", size); err != nil {
return nil, errors.Wrap(err, "convert coordinate action failed")
}
if err := convertCoordinateAction(&actions[i], "endBox", size); err != nil {
return nil, errors.Wrap(err, "convert coordinate action failed")
}
case "type":
validateTypeContent(&actions[i])
case "wait", "finished", "call_user":
@ -221,9 +229,38 @@ func extractActionSummary(actions []ParsedAction) string {
}
}
// validateCoordinateAction 验证坐标类动作
func validateCoordinateAction(action *ParsedAction, boxField string) {
// TODO
// convertCoordinateAction rescales the coordinates stored under boxField in
// action.ActionInputs from the model's relative space to absolute positions
// within an image of the given size.
//
// The model generates coordinates relative to 1000, so each component is
// converted with:
//   - X absolute = X relative × image width / 1000
//   - Y absolute = Y relative × image height / 1000
//
// The field must hold a []float64 that is either a point (x, y) or a box
// (x1, y1, x2, y2). The slice is rewritten in place and every component is
// rounded to one decimal place.
//
// An error is returned when size is not positive in both dimensions, when the
// field is missing or is not a []float64, or when it has any other length.
func convertCoordinateAction(action *ParsedAction, boxField string, size types.Size) error {
	// A zero (e.g. unset) size would silently collapse every coordinate to the
	// origin, so reject it instead of producing a bogus position.
	if size.Width <= 0 || size.Height <= 0 {
		log.Error().Interface("size", size).Msg("invalid image size")
		return fmt.Errorf("invalid image size: %dx%d", size.Width, size.Height)
	}

	coords, ok := action.ActionInputs[boxField].([]float64)
	if !ok {
		log.Error().Interface("inputs", action.ActionInputs).Msg("invalid action inputs")
		return fmt.Errorf("invalid action inputs: %s is missing or not a coordinate list", boxField)
	}
	if len(coords) != 2 && len(coords) != 4 {
		log.Error().Interface("inputs", action.ActionInputs).Msg("invalid action inputs")
		return fmt.Errorf("invalid action inputs: %s has %d coordinates, want 2 or 4", boxField, len(coords))
	}

	// Components alternate x, y, so even indexes scale by the width and odd
	// indexes by the height.
	for i := range coords {
		extent := size.Width
		if i%2 == 1 {
			extent = size.Height
		}
		coords[i] = math.Round(coords[i]/1000*float64(extent)*10) / 10
	}
	return nil
}
// validateTypeContent 验证输入文本内容
@ -303,11 +340,32 @@ func SavePositionImg(params struct {
}
// loadImage loads image and returns base64 encoded string
func loadImage(imagePath string) (base64Str string, err error) {
imageData, err := os.ReadFile(imagePath)
// loadImage reads the image at imagePath and returns it as a base64 data URL
// together with its dimensions.
//
// The file is decoded with image.Decode and then re-encoded with png.Encode,
// so the returned payload is always PNG data whatever the source format was.
// The data URL is therefore always labelled image/png; labelling it with the
// source format would mislabel the payload for non-PNG inputs such as JPEG.
//
// On failure the returned string is empty, the size is the zero value, and the
// error wraps the underlying open, decode, or encode error.
func loadImage(imagePath string) (base64Str string, size types.Size, err error) {
	// Read the image file
	imageFile, err := os.Open(imagePath)
	if err != nil {
		return "", types.Size{}, fmt.Errorf("failed to open image file: %w", err)
	}
	defer imageFile.Close()

	// Decode the image to get its resolution. The detected source format is
	// intentionally discarded because the payload below is re-encoded as PNG.
	imageData, _, err := image.Decode(imageFile)
	if err != nil {
		return "", types.Size{}, fmt.Errorf("failed to decode image: %w", err)
	}

	// Get the resolution of the image
	bounds := imageData.Bounds()
	size = types.Size{Width: bounds.Dx(), Height: bounds.Dy()}

	// Convert image to base64
	var buf bytes.Buffer
	if err := png.Encode(&buf, imageData); err != nil {
		return "", types.Size{}, fmt.Errorf("failed to encode image to buffer: %w", err)
	}
	base64Str = "data:image/png;base64," + base64.StdEncoding.EncodeToString(buf.Bytes())
	return base64Str, size, nil
}

View File

@ -6,12 +6,13 @@ import (
"testing"
"github.com/cloudwego/eino/schema"
"github.com/httprunner/httprunner/v5/uixt/types"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestVLMPlanning(t *testing.T) {
imageBase64, err := loadImage("testdata/llk_1.png")
imageBase64, size, err := loadImage("testdata/llk_1.png")
require.NoError(t, err)
userInstruction := `连连看是一款经典的益智消除类小游戏通常以图案或图标为主要元素以下是连连看的基本规则说明
@ -25,7 +26,7 @@ func TestVLMPlanning(t *testing.T) {
5. 得分机制: 每成功连接并消除一对图案玩家会获得相应的分数完成游戏后根据剩余时间和消除效率计算总分
6. 关卡设计: 游戏可能包含多个关卡随着关卡的推进图案的复杂度和数量会增加`
userInstruction += "\n\n请基于以上游戏规则请依次点击两个可消除的相同图案"
userInstruction += "\n\n请基于以上游戏规则给出下一步可点击的两个图标坐标"
planner, err := NewPlanner(context.Background())
require.NoError(t, err)
@ -45,6 +46,7 @@ func TestVLMPlanning(t *testing.T) {
},
},
},
Size: size,
}
// 执行规划
@ -62,7 +64,7 @@ func TestVLMPlanning(t *testing.T) {
// 根据动作类型验证参数
switch action.ActionType {
case ActionTypeClick:
case "click", "drag", "left_double", "right_single", "scroll":
// 这些动作需要验证坐标
assert.NotEmpty(t, action.ActionInputs["startBox"])
@ -76,6 +78,14 @@ func TestVLMPlanning(t *testing.T) {
assert.GreaterOrEqual(t, coord, float64(0))
}
case "type":
// 验证文本内容
assert.NotEmpty(t, action.ActionInputs["content"])
case "hotkey":
// 验证按键
assert.NotEmpty(t, action.ActionInputs["key"])
case "wait", "finished", "call_user":
// 这些动作不需要额外参数
@ -85,10 +95,13 @@ func TestVLMPlanning(t *testing.T) {
}
func TestXHSPlanning(t *testing.T) {
imageBase64, err := loadImage("testdata/xhs-feed.jpeg")
err := loadEnv()
require.NoError(t, err)
userInstruction := `点击第二个帖子的作者头像`
imageBase64, size, err := loadImage("testdata/xhs-feed.jpeg")
require.NoError(t, err)
userInstruction := "点击第二个帖子的作者头像"
planner, err := NewPlanner(context.Background())
require.NoError(t, err)
@ -108,6 +121,7 @@ func TestXHSPlanning(t *testing.T) {
},
},
},
Size: size,
}
// 执行规划
@ -122,10 +136,41 @@ func TestXHSPlanning(t *testing.T) {
action := result.NextActions[0]
assert.NotEmpty(t, action.ActionType)
assert.NotEmpty(t, action.Thought)
// 根据动作类型验证参数
switch action.ActionType {
case "click", "drag", "left_double", "right_single", "scroll":
// 这些动作需要验证坐标
assert.NotEmpty(t, action.ActionInputs["startBox"])
// 验证坐标格式
coords, ok := action.ActionInputs["startBox"].([]float64)
require.True(t, ok)
require.True(t, len(coords) >= 2) // 至少有 x, y 坐标
// 验证坐标范围
for _, coord := range coords {
assert.GreaterOrEqual(t, coord, float64(0))
}
case "type":
// 验证文本内容
assert.NotEmpty(t, action.ActionInputs["content"])
case "hotkey":
// 验证按键
assert.NotEmpty(t, action.ActionInputs["key"])
case "wait", "finished", "call_user":
// 这些动作不需要额外参数
default:
t.Fatalf("未知的动作类型: %s", action.ActionType)
}
}
func TestValidateInput(t *testing.T) {
imageBase64, err := loadImage("testdata/popup_risk_warning.png")
imageBase64, size, err := loadImage("testdata/popup_risk_warning.png")
require.NoError(t, err)
tests := []struct {
@ -150,6 +195,7 @@ func TestValidateInput(t *testing.T) {
},
},
},
Size: size,
},
wantErr: nil,
},
@ -163,6 +209,7 @@ func TestValidateInput(t *testing.T) {
Content: "",
},
},
Size: size,
},
wantErr: ErrEmptyInstruction,
},
@ -184,6 +231,7 @@ func TestValidateInput(t *testing.T) {
Content: "no image",
},
},
Size: size,
},
wantErr: ErrInvalidImageData,
},
@ -228,9 +276,13 @@ func TestProcessVLMResponse(t *testing.T) {
},
}
size := types.Size{
Width: 1000,
Height: 1000,
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := processVLMResponse(tt.actions)
result, err := processVLMResponse(tt.actions, size)
if tt.wantErr {
assert.Error(t, err)
assert.Nil(t, result)
@ -245,7 +297,7 @@ func TestProcessVLMResponse(t *testing.T) {
}
func TestSavePositionImg(t *testing.T) {
imageBase64, err := loadImage("testdata/popup_risk_warning.png")
imageBase64, _, err := loadImage("testdata/popup_risk_warning.png")
require.NoError(t, err)
params := struct {
@ -273,3 +325,19 @@ func TestSavePositionImg(t *testing.T) {
// cleanup
defer os.Remove(params.OutputPath)
}
// TestLoadImage verifies that loadImage returns a non-empty base64 string and
// positive dimensions for a PNG fixture and a JPEG fixture, in that order.
func TestLoadImage(t *testing.T) {
	fixtures := []string{
		// PNG image
		"testdata/llk_1.png",
		// JPEG image
		"testdata/xhs-feed.jpeg",
	}

	for _, fixture := range fixtures {
		encoded, dims, err := loadImage(fixture)
		require.NoError(t, err)
		assert.NotEmpty(t, encoded)
		assert.Greater(t, dims.Width, 0)
		assert.Greater(t, dims.Height, 0)
	}
}