httprunner/planner/parser.go

240 lines
7.3 KiB
Go

package planner
import (
"encoding/json"
"fmt"
"regexp"
"strings"
"github.com/pkg/errors"
)
// NewActionParser allocates an ActionParser and stores the given factor in
// its Factor field.
func NewActionParser(factor float64) *ActionParser {
	parser := new(ActionParser)
	parser.Factor = factor
	return parser
}
// ActionParser parses VLM responses and converts them to structured actions.
// Its Parse method accepts either a JSON payload or a Thought/Action text
// reply and returns the extracted ParsedAction values.
type ActionParser struct {
	// Factor is the value handed to NewActionParser.
	// NOTE(review): none of the methods visible in this file read it;
	// presumably it is reserved for coordinate scaling — confirm before
	// relying on it.
	Factor float64 // TODO
}
// Parse extracts actions from the raw VLM prediction text.
//
// Two response formats are attempted in order: the JSON format (from VLMs
// like GPT-4o) and then the Thought/Action format (from VLMs like UI-TARS).
// The first format that parses successfully wins. When both fail, the
// returned error reports both underlying failures.
func (p *ActionParser) Parse(predictionText string) ([]ParsedAction, error) {
	// First attempt: JSON payload.
	actions, jsonErr := p.parseJSON(predictionText)
	if jsonErr == nil {
		return actions, nil
	}

	// Second attempt: Thought/Action text.
	actions, thoughtErr := p.parseThoughtAction(predictionText)
	if thoughtErr != nil {
		return nil, fmt.Errorf("no valid actions returned from VLM, jsonErr: %v, thoughtErr: %v", jsonErr, thoughtErr)
	}
	return actions, nil
}
// parseJSON decodes a JSON-formatted VLM reply into actions.
//
// The text may be wrapped in a ```json ... ``` Markdown fence, which is
// removed before decoding. An error is returned when the payload is not valid
// JSON, when the decoded response carries a non-empty Error field, when it
// holds no actions, or when normalizeAction rejects one of the actions.
func (p *ActionParser) parseJSON(predictionText string) ([]ParsedAction, error) {
	const (
		fenceOpen  = "```json"
		fenceClose = "```"
	)

	payload := strings.TrimSpace(predictionText)
	fenced := strings.HasPrefix(payload, fenceOpen) && strings.HasSuffix(payload, fenceClose)
	if fenced {
		payload = strings.TrimSuffix(strings.TrimPrefix(payload, fenceOpen), fenceClose)
	}
	payload = strings.TrimSpace(payload)

	var resp VLMResponse
	err := json.Unmarshal([]byte(payload), &resp)
	if err != nil {
		return nil, fmt.Errorf("failed to parse VLM response: %v", err)
	}

	switch {
	case resp.Error != "":
		// The model reported a failure of its own; surface it verbatim.
		return nil, errors.New(resp.Error)
	case len(resp.Actions) == 0:
		return nil, errors.New("no actions returned from VLM")
	}

	// Normalize a copy of every action and collect the copies in order.
	normalized := make([]ParsedAction, 0, len(resp.Actions))
	for i := range resp.Actions {
		current := resp.Actions[i]
		if err := p.normalizeAction(&current); err != nil {
			return nil, errors.Wrap(err, "failed to normalize action")
		}
		normalized = append(normalized, current)
	}
	return normalized, nil
}
// Regular expressions used by parseThoughtAction. They are compiled once at
// package initialization instead of on every call. Both are case-insensitive
// and let "." match newlines (flags "is").
var (
	// thoughtSectionRegex captures the text between "Thought:" and the
	// nearest following "Action:".
	thoughtSectionRegex = regexp.MustCompile(`(?is)Thought:(.+?)Action:`)
	// actionSectionRegex captures everything after the first "Action:".
	actionSectionRegex = regexp.MustCompile(`(?is)Action:(.+)`)
)

// parseThoughtAction parses a Thought/Action formatted reply.
//
// The Thought section is optional: when it is missing the thought passed on
// is the empty string. The Action section is mandatory; its trimmed text
// (e.g. "click(start_box='(552,454)')") is handed to parseActionText together
// with the trimmed thought, and that call's result is returned unchanged.
//
// An error is returned when no non-empty "Action:" section is present.
func (p *ActionParser) parseThoughtAction(predictionText string) ([]ParsedAction, error) {
	var thought string
	if m := thoughtSectionRegex.FindStringSubmatch(predictionText); len(m) > 1 {
		thought = strings.TrimSpace(m[1])
	}

	m := actionSectionRegex.FindStringSubmatch(predictionText)
	if len(m) < 2 {
		return nil, fmt.Errorf("no action found in the response")
	}

	// Resolve the action type and its parameters from the action text.
	return p.parseActionText(strings.TrimSpace(m[1]), thought)
}
// thoughtActionPattern associates a supported action name with the regular
// expression that recognizes its call syntax in Thought/Action replies.
type thoughtActionPattern struct {
	actionType string
	re         *regexp.Regexp
}

// thoughtActionPatterns lists every supported action. The expressions are
// compiled once at package initialization rather than on each call, and a
// slice is used instead of a map so that scanning never depends on map
// iteration order.
var thoughtActionPatterns = []thoughtActionPattern{
	{"click", regexp.MustCompile(`click\(start_box='([^']+)'\)`)},
	{"left_double", regexp.MustCompile(`left_double\(start_box='([^']+)'\)`)},
	{"right_single", regexp.MustCompile(`right_single\(start_box='([^']+)'\)`)},
	{"drag", regexp.MustCompile(`drag\(start_box='([^']+)', end_box='([^']+)'\)`)},
	{"hotkey", regexp.MustCompile(`hotkey\(key='([^']+)'\)`)},
	{"type", regexp.MustCompile(`type\(content='([^']+)'\)`)},
	{"scroll", regexp.MustCompile(`scroll\(start_box='([^']+)', direction='([^']+)'\)`)},
	{"wait", regexp.MustCompile(`wait\(\)`)},
	{"finished", regexp.MustCompile(`finished\(\)`)},
	{"call_user", regexp.MustCompile(`call_user\(\)`)},
}

// parseActionText parses the action text to extract the action types and
// their parameters.
//
// The text is scanned from left to right and every supported action call
// found is turned into a ParsedAction carrying the given thought. Actions are
// returned in the order in which they appear in the text, so the result is
// deterministic. Text that lies inside an already recognized call (for
// example the content argument of type(...)) is never interpreted as a
// further action.
//
// A trailing "# ..." comment is dropped before scanning; a '#' inside a
// single-quoted argument is part of that argument, not a comment.
//
// An error is returned when a coordinate argument cannot be parsed by
// normalizeCoordinates, or when no supported action is found.
func (p *ActionParser) parseActionText(actionText, thought string) ([]ParsedAction, error) {
	// remove trailing comments
	if idx := actionTextCommentIndex(actionText); idx > 0 {
		actionText = strings.TrimSpace(actionText[:idx])
	}

	var parsedActions []ParsedAction
	for remaining := actionText; remaining != ""; {
		pattern, loc := nextActionTextMatch(remaining)
		if loc == nil {
			break
		}
		action, err := p.buildActionFromMatch(pattern.actionType, actionTextGroups(remaining, loc), thought)
		if err != nil {
			return nil, err
		}
		parsedActions = append(parsedActions, action)
		// Resume right after this call; every pattern matches at least one
		// character, so the loop always makes progress.
		remaining = remaining[loc[1]:]
	}

	if len(parsedActions) == 0 {
		return nil, fmt.Errorf("no valid actions returned from VLM")
	}
	return parsedActions, nil
}

// actionTextCommentIndex returns the byte index of the first '#' in text that
// lies outside a single-quoted argument, or -1 when there is none. When the
// quotes in text are unbalanced, quoting cannot be trusted and the index of
// the first '#' (if any) is returned instead.
func actionTextCommentIndex(text string) int {
	firstHash := -1
	inQuote := false
	// Byte-wise scanning is safe: '#' and '\'' are ASCII and never occur
	// inside a multi-byte UTF-8 sequence.
	for i := 0; i < len(text); i++ {
		switch text[i] {
		case '\'':
			inQuote = !inQuote
		case '#':
			if !inQuote {
				return i
			}
			if firstHash < 0 {
				firstHash = i
			}
		}
	}
	if inQuote {
		return firstHash
	}
	return -1
}

// nextActionTextMatch returns the supported action whose call syntax starts
// earliest in text, together with its submatch index pairs as produced by
// Regexp.FindStringSubmatchIndex. loc is nil when no action matches.
func nextActionTextMatch(text string) (pattern thoughtActionPattern, loc []int) {
	for _, candidate := range thoughtActionPatterns {
		found := candidate.re.FindStringSubmatchIndex(text)
		if found == nil {
			continue
		}
		if loc == nil || found[0] < loc[0] {
			pattern, loc = candidate, found
		}
	}
	return pattern, loc
}

// actionTextGroups converts submatch index pairs into the matched strings:
// element 0 is the whole match and element i is capture group i.
func actionTextGroups(text string, loc []int) []string {
	groups := make([]string, len(loc)/2)
	for i := range groups {
		if start, end := loc[2*i], loc[2*i+1]; start >= 0 {
			groups[i] = text[start:end]
		}
	}
	return groups
}

// buildActionFromMatch creates the ParsedAction for one recognized call.
// matches holds the whole match followed by its capture groups; the number of
// groups for each action type is fixed by thoughtActionPatterns.
func (p *ActionParser) buildActionFromMatch(actionType string, matches []string, thought string) (ParsedAction, error) {
	var action ParsedAction
	action.ActionType = actionType
	action.ActionInputs = make(map[string]interface{})
	action.Thought = thought

	// parse parameters based on action type
	switch actionType {
	case "click", "left_double", "right_single":
		coord, err := p.normalizeCoordinates(matches[1])
		if err != nil {
			return ParsedAction{}, errors.Wrapf(err, "normalize point failed: %s", matches[1])
		}
		action.ActionInputs["startBox"] = coord
	case "drag":
		// handle start point
		startBox, err := p.normalizeCoordinates(matches[1])
		if err != nil {
			return ParsedAction{}, errors.Wrapf(err, "normalize startBox failed: %s", matches[1])
		}
		action.ActionInputs["startBox"] = startBox
		// handle end point
		endBox, err := p.normalizeCoordinates(matches[2])
		if err != nil {
			return ParsedAction{}, errors.Wrapf(err, "normalize endBox failed: %s", matches[2])
		}
		action.ActionInputs["endBox"] = endBox
	case "hotkey":
		action.ActionInputs["key"] = matches[1]
	case "type":
		action.ActionInputs["content"] = matches[1]
	case "scroll":
		startBox, err := p.normalizeCoordinates(matches[1])
		if err != nil {
			return ParsedAction{}, errors.Wrapf(err, "normalize startBox failed: %s", matches[1])
		}
		action.ActionInputs["startBox"] = startBox
		action.ActionInputs["direction"] = matches[2]
	case "wait", "finished", "call_user":
		// these actions take no extra parameters
	}
	return action, nil
}
// normalizeAction converts string-encoded coordinates of a click or drag
// action in place.
//
// For "click" and "drag" actions, the "startBox" and "endBox" inputs are each
// replaced by the result of normalizeCoordinates when they currently hold a
// string. Inputs that are absent or not strings, and actions of any other
// type, are left untouched. The first normalization failure is returned.
func (p *ActionParser) normalizeAction(action *ParsedAction) error {
	if kind := action.ActionType; kind != "click" && kind != "drag" {
		return nil
	}

	// Maps are reference values, so writes through inputs update the action.
	inputs := action.ActionInputs

	if raw, isString := inputs["startBox"].(string); isString {
		box, err := p.normalizeCoordinates(raw)
		if err != nil {
			return fmt.Errorf("failed to normalize startBox: %w", err)
		}
		inputs["startBox"] = box
	}

	if raw, isString := inputs["endBox"].(string); isString {
		box, err := p.normalizeCoordinates(raw)
		if err != nil {
			return fmt.Errorf("failed to normalize endBox: %w", err)
		}
		inputs["endBox"] = box
	}

	return nil
}
// normalizeCoordinates parses a textual coordinate list such as "(552,454)"
// or "[0.1, 0.2, 0.3, 0.4]" into a slice of float64 values.
//
// Surrounding brackets, parentheses, spaces and tabs are stripped, and the
// remaining comma-separated numbers are decoded as a JSON array. The values
// are returned exactly as written: despite the method name, p.Factor is not
// applied here.
//
// An error is returned when coordStr is empty, contains no comma, or does not
// decode as a list of numbers.
func (p *ActionParser) normalizeCoordinates(coordStr string) (coords []float64, err error) {
	if coordStr == "" {
		return nil, fmt.Errorf("empty coordinate string")
	}
	if !strings.Contains(coordStr, ",") {
		return nil, fmt.Errorf("invalid coordinate string: %s", coordStr)
	}

	// Trim strips every leading and trailing bracket, so the remainder can
	// never still start with '[' and always has to be re-wrapped before it is
	// decoded as a JSON array.
	body := strings.Trim(coordStr, "[]() \t")
	if err = json.Unmarshal([]byte("["+body+"]"), &coords); err != nil {
		return nil, fmt.Errorf("failed to parse coordinate string: %w", err)
	}
	return coords, nil
}