fix: convertCoordinateAction
This commit is contained in:
parent
3801ffb744
commit
da0bdc4fe5
|
@ -1 +1 @@
|
|||
v5.0.0-beta-2503201423
|
||||
v5.0.0-beta-2503201802
|
||||
|
|
|
@ -2,6 +2,7 @@ package ai
|
|||
|
||||
import (
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"github.com/httprunner/httprunner/v5/uixt/types"
|
||||
)
|
||||
|
||||
type ILLMService interface {
|
||||
|
@ -22,6 +23,7 @@ func (s openaiLLMService) Call(opts *PlanningOptions) (*PlanningResult, error) {
|
|||
type PlanningOptions struct {
|
||||
UserInstruction string `json:"user_instruction"`
|
||||
ConversationHistory []*schema.Message `json:"conversation_history"`
|
||||
Size types.Size `json:"size"`
|
||||
}
|
||||
|
||||
// PlanningResult represents the result of planning
|
||||
|
|
|
@ -9,12 +9,14 @@ import (
|
|||
"image/color"
|
||||
"image/draw"
|
||||
"image/png"
|
||||
"math"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/cloudwego/eino-ext/components/model/openai"
|
||||
"github.com/cloudwego/eino/components/model"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"github.com/httprunner/httprunner/v5/uixt/types"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
@ -65,7 +67,7 @@ func (p *Planner) Call(opts *PlanningOptions) (*PlanningResult, error) {
|
|||
}
|
||||
|
||||
// parse result
|
||||
result, err := p.parseResult(resp)
|
||||
result, err := p.parseResult(resp, opts.Size)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "parse result failed")
|
||||
}
|
||||
|
@ -135,15 +137,15 @@ func (p *Planner) callVLMService(opts *PlanningOptions) (*schema.Message, error)
|
|||
return resp, nil
|
||||
}
|
||||
|
||||
func (p *Planner) parseResult(msg *schema.Message) (*PlanningResult, error) {
|
||||
func (p *Planner) parseResult(msg *schema.Message, size types.Size) (*PlanningResult, error) {
|
||||
// parse response
|
||||
actions, err := p.parser.Parse(msg.Content)
|
||||
parseActions, err := p.parser.Parse(msg.Content)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse actions: %w", err)
|
||||
}
|
||||
|
||||
// process response
|
||||
result, err := processVLMResponse(actions)
|
||||
result, err := processVLMResponse(parseActions, size)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "process VLM response failed")
|
||||
}
|
||||
|
@ -152,7 +154,7 @@ func (p *Planner) parseResult(msg *schema.Message) (*PlanningResult, error) {
|
|||
}
|
||||
|
||||
// processVLMResponse processes the VLM response and converts it to PlanningResult
|
||||
func processVLMResponse(actions []ParsedAction) (*PlanningResult, error) {
|
||||
func processVLMResponse(actions []ParsedAction, size types.Size) (*PlanningResult, error) {
|
||||
log.Info().Msg("processing VLM response...")
|
||||
|
||||
if len(actions) == 0 {
|
||||
|
@ -163,11 +165,17 @@ func processVLMResponse(actions []ParsedAction) (*PlanningResult, error) {
|
|||
for i := range actions {
|
||||
// validate action type
|
||||
switch actions[i].ActionType {
|
||||
case "click", "left_double", "right_single":
|
||||
validateCoordinateAction(&actions[i], "startBox")
|
||||
case "click":
|
||||
if err := convertCoordinateAction(&actions[i], "startBox", size); err != nil {
|
||||
return nil, errors.Wrap(err, "convert coordinate action failed")
|
||||
}
|
||||
case "drag":
|
||||
validateCoordinateAction(&actions[i], "startBox")
|
||||
validateCoordinateAction(&actions[i], "endBox")
|
||||
if err := convertCoordinateAction(&actions[i], "startBox", size); err != nil {
|
||||
return nil, errors.Wrap(err, "convert coordinate action failed")
|
||||
}
|
||||
if err := convertCoordinateAction(&actions[i], "endBox", size); err != nil {
|
||||
return nil, errors.Wrap(err, "convert coordinate action failed")
|
||||
}
|
||||
case "type":
|
||||
validateTypeContent(&actions[i])
|
||||
case "wait", "finished", "call_user":
|
||||
|
@ -221,9 +229,38 @@ func extractActionSummary(actions []ParsedAction) string {
|
|||
}
|
||||
}
|
||||
|
||||
// validateCoordinateAction 验证坐标类动作
|
||||
func validateCoordinateAction(action *ParsedAction, boxField string) {
|
||||
// TODO
|
||||
func convertCoordinateAction(action *ParsedAction, boxField string, size types.Size) error {
|
||||
// The model generates a 2D coordinate output that represents relative positions.
|
||||
// To convert these values to image-relative coordinates, divide each component by 1000 to obtain values in the range [0,1].
|
||||
// The absolute coordinates required by the Action can be calculated by:
|
||||
// - X absolute = X relative × image width / 1000
|
||||
// - Y absolute = Y relative × image height / 1000
|
||||
|
||||
// get image width and height
|
||||
imageWidth := size.Width
|
||||
imageHeight := size.Height
|
||||
|
||||
box := action.ActionInputs[boxField]
|
||||
coords, ok := box.([]float64)
|
||||
if !ok {
|
||||
log.Error().Interface("inputs", action.ActionInputs).Msg("invalid action inputs")
|
||||
return fmt.Errorf("invalid action inputs")
|
||||
}
|
||||
|
||||
if len(coords) == 2 {
|
||||
coords[0] = math.Round((coords[0]/1000*float64(imageWidth))*10) / 10
|
||||
coords[1] = math.Round((coords[1]/1000*float64(imageHeight))*10) / 10
|
||||
} else if len(coords) == 4 {
|
||||
coords[0] = math.Round((coords[0]/1000*float64(imageWidth))*10) / 10
|
||||
coords[1] = math.Round((coords[1]/1000*float64(imageHeight))*10) / 10
|
||||
coords[2] = math.Round((coords[2]/1000*float64(imageWidth))*10) / 10
|
||||
coords[3] = math.Round((coords[3]/1000*float64(imageHeight))*10) / 10
|
||||
} else {
|
||||
log.Error().Interface("inputs", action.ActionInputs).Msg("invalid action inputs")
|
||||
return fmt.Errorf("invalid action inputs")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateTypeContent 验证输入文本内容
|
||||
|
@ -303,11 +340,32 @@ func SavePositionImg(params struct {
|
|||
}
|
||||
|
||||
// loadImage loads image and returns base64 encoded string
|
||||
func loadImage(imagePath string) (base64Str string, err error) {
|
||||
imageData, err := os.ReadFile(imagePath)
|
||||
func loadImage(imagePath string) (base64Str string, size types.Size, err error) {
|
||||
// Read the image file
|
||||
imageFile, err := os.Open(imagePath)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return "", types.Size{}, fmt.Errorf("failed to open image file: %w", err)
|
||||
}
|
||||
base64Str = "data:image/png;base64," + base64.StdEncoding.EncodeToString(imageData)
|
||||
return
|
||||
defer imageFile.Close()
|
||||
|
||||
// Decode the image to get its resolution
|
||||
imageData, format, err := image.Decode(imageFile)
|
||||
if err != nil {
|
||||
return "", types.Size{}, fmt.Errorf("failed to decode image: %w", err)
|
||||
}
|
||||
|
||||
// Get the resolution of the image
|
||||
width := imageData.Bounds().Dx()
|
||||
height := imageData.Bounds().Dy()
|
||||
size = types.Size{Width: width, Height: height}
|
||||
|
||||
// Convert image to base64
|
||||
buf := new(bytes.Buffer)
|
||||
if err := png.Encode(buf, imageData); err != nil {
|
||||
return "", types.Size{}, fmt.Errorf("failed to encode image to buffer: %w", err)
|
||||
}
|
||||
base64Str = fmt.Sprintf("data:image/%s;base64,%s", format,
|
||||
base64.StdEncoding.EncodeToString(buf.Bytes()))
|
||||
|
||||
return base64Str, size, nil
|
||||
}
|
||||
|
|
|
@ -6,12 +6,13 @@ import (
|
|||
"testing"
|
||||
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"github.com/httprunner/httprunner/v5/uixt/types"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestVLMPlanning(t *testing.T) {
|
||||
imageBase64, err := loadImage("testdata/llk_1.png")
|
||||
imageBase64, size, err := loadImage("testdata/llk_1.png")
|
||||
require.NoError(t, err)
|
||||
|
||||
userInstruction := `连连看是一款经典的益智消除类小游戏,通常以图案或图标为主要元素。以下是连连看的基本规则说明:
|
||||
|
@ -25,7 +26,7 @@ func TestVLMPlanning(t *testing.T) {
|
|||
5. 得分机制: 每成功连接并消除一对图案,玩家会获得相应的分数。完成游戏后,根据剩余时间和消除效率计算总分。
|
||||
6. 关卡设计: 游戏可能包含多个关卡,随着关卡的推进,图案的复杂度和数量会增加。`
|
||||
|
||||
userInstruction += "\n\n请基于以上游戏规则,请依次点击两个可消除的相同图案"
|
||||
userInstruction += "\n\n请基于以上游戏规则,给出下一步可点击的两个图标坐标"
|
||||
|
||||
planner, err := NewPlanner(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
@ -45,6 +46,7 @@ func TestVLMPlanning(t *testing.T) {
|
|||
},
|
||||
},
|
||||
},
|
||||
Size: size,
|
||||
}
|
||||
|
||||
// 执行规划
|
||||
|
@ -62,7 +64,7 @@ func TestVLMPlanning(t *testing.T) {
|
|||
|
||||
// 根据动作类型验证参数
|
||||
switch action.ActionType {
|
||||
case ActionTypeClick:
|
||||
case "click", "drag", "left_double", "right_single", "scroll":
|
||||
// 这些动作需要验证坐标
|
||||
assert.NotEmpty(t, action.ActionInputs["startBox"])
|
||||
|
||||
|
@ -76,6 +78,14 @@ func TestVLMPlanning(t *testing.T) {
|
|||
assert.GreaterOrEqual(t, coord, float64(0))
|
||||
}
|
||||
|
||||
case "type":
|
||||
// 验证文本内容
|
||||
assert.NotEmpty(t, action.ActionInputs["content"])
|
||||
|
||||
case "hotkey":
|
||||
// 验证按键
|
||||
assert.NotEmpty(t, action.ActionInputs["key"])
|
||||
|
||||
case "wait", "finished", "call_user":
|
||||
// 这些动作不需要额外参数
|
||||
|
||||
|
@ -85,10 +95,13 @@ func TestVLMPlanning(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestXHSPlanning(t *testing.T) {
|
||||
imageBase64, err := loadImage("testdata/xhs-feed.jpeg")
|
||||
err := loadEnv()
|
||||
require.NoError(t, err)
|
||||
|
||||
userInstruction := `点击第二个帖子的作者头像`
|
||||
imageBase64, size, err := loadImage("testdata/xhs-feed.jpeg")
|
||||
require.NoError(t, err)
|
||||
|
||||
userInstruction := "点击第二个帖子的作者头像"
|
||||
|
||||
planner, err := NewPlanner(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
@ -108,6 +121,7 @@ func TestXHSPlanning(t *testing.T) {
|
|||
},
|
||||
},
|
||||
},
|
||||
Size: size,
|
||||
}
|
||||
|
||||
// 执行规划
|
||||
|
@ -122,10 +136,41 @@ func TestXHSPlanning(t *testing.T) {
|
|||
action := result.NextActions[0]
|
||||
assert.NotEmpty(t, action.ActionType)
|
||||
assert.NotEmpty(t, action.Thought)
|
||||
|
||||
// 根据动作类型验证参数
|
||||
switch action.ActionType {
|
||||
case "click", "drag", "left_double", "right_single", "scroll":
|
||||
// 这些动作需要验证坐标
|
||||
assert.NotEmpty(t, action.ActionInputs["startBox"])
|
||||
|
||||
// 验证坐标格式
|
||||
coords, ok := action.ActionInputs["startBox"].([]float64)
|
||||
require.True(t, ok)
|
||||
require.True(t, len(coords) >= 2) // 至少有 x, y 坐标
|
||||
|
||||
// 验证坐标范围
|
||||
for _, coord := range coords {
|
||||
assert.GreaterOrEqual(t, coord, float64(0))
|
||||
}
|
||||
|
||||
case "type":
|
||||
// 验证文本内容
|
||||
assert.NotEmpty(t, action.ActionInputs["content"])
|
||||
|
||||
case "hotkey":
|
||||
// 验证按键
|
||||
assert.NotEmpty(t, action.ActionInputs["key"])
|
||||
|
||||
case "wait", "finished", "call_user":
|
||||
// 这些动作不需要额外参数
|
||||
|
||||
default:
|
||||
t.Fatalf("未知的动作类型: %s", action.ActionType)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateInput(t *testing.T) {
|
||||
imageBase64, err := loadImage("testdata/popup_risk_warning.png")
|
||||
imageBase64, size, err := loadImage("testdata/popup_risk_warning.png")
|
||||
require.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
|
@ -150,6 +195,7 @@ func TestValidateInput(t *testing.T) {
|
|||
},
|
||||
},
|
||||
},
|
||||
Size: size,
|
||||
},
|
||||
wantErr: nil,
|
||||
},
|
||||
|
@ -163,6 +209,7 @@ func TestValidateInput(t *testing.T) {
|
|||
Content: "",
|
||||
},
|
||||
},
|
||||
Size: size,
|
||||
},
|
||||
wantErr: ErrEmptyInstruction,
|
||||
},
|
||||
|
@ -184,6 +231,7 @@ func TestValidateInput(t *testing.T) {
|
|||
Content: "no image",
|
||||
},
|
||||
},
|
||||
Size: size,
|
||||
},
|
||||
wantErr: ErrInvalidImageData,
|
||||
},
|
||||
|
@ -228,9 +276,13 @@ func TestProcessVLMResponse(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
size := types.Size{
|
||||
Width: 1000,
|
||||
Height: 1000,
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := processVLMResponse(tt.actions)
|
||||
result, err := processVLMResponse(tt.actions, size)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
|
@ -245,7 +297,7 @@ func TestProcessVLMResponse(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestSavePositionImg(t *testing.T) {
|
||||
imageBase64, err := loadImage("testdata/popup_risk_warning.png")
|
||||
imageBase64, _, err := loadImage("testdata/popup_risk_warning.png")
|
||||
require.NoError(t, err)
|
||||
|
||||
params := struct {
|
||||
|
@ -273,3 +325,19 @@ func TestSavePositionImg(t *testing.T) {
|
|||
// cleanup
|
||||
defer os.Remove(params.OutputPath)
|
||||
}
|
||||
|
||||
func TestLoadImage(t *testing.T) {
|
||||
// Test PNG image
|
||||
pngBase64, pngSize, err := loadImage("testdata/llk_1.png")
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, pngBase64)
|
||||
assert.Greater(t, pngSize.Width, 0)
|
||||
assert.Greater(t, pngSize.Height, 0)
|
||||
|
||||
// Test JPEG image
|
||||
jpegBase64, jpegSize, err := loadImage("testdata/xhs-feed.jpeg")
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, jpegBase64)
|
||||
assert.Greater(t, jpegSize.Width, 0)
|
||||
assert.Greater(t, jpegSize.Height, 0)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue