api-testing/pkg/runner/grpc.go

576 lines
15 KiB
Go

/*
MIT License
Copyright (c) 2023 API Testing Authors.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
*/
package runner
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"regexp"
"strings"
"time"
"github.com/bufbuild/protocompile"
"github.com/linuxsuren/api-testing/pkg/compare"
"github.com/linuxsuren/api-testing/pkg/testing"
"github.com/tidwall/gjson"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/reflection/grpc_reflection_v1"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protodesc"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/reflect/protoregistry"
"google.golang.org/protobuf/types/descriptorpb"
"google.golang.org/protobuf/types/dynamicpb"
)
// gRPCTestCaseRunner runs test cases against a gRPC endpoint. Method
// descriptors are obtained either through server reflection or from the
// proto sources configured in proto.
type gRPCTestCaseRunner struct {
	UnimplementedRunner
	// host is the fallback dial target, used when a test case's API string
	// does not carry its own host (see getHost).
	host string
	// proto describes where method descriptors come from: server reflection,
	// a proto file (plus import paths), raw proto content, or a protoset
	// (local path or http(s) URL).
	proto testing.RPCDesc
	// response is what GetResponseRecord returns.
	response SimpleResponse
	// Disabled per-service file-descriptor cache, kept for reference:
	// fdCache sync.Map
}
var (
	// regexFullQualifiedName matches "<host>/<package.Service>/<Method>".
	// Group 1 is used as the dial target; groups 2 and 3 are joined with
	// '.' to form the method's full protobuf name.
	regexFullQualifiedName = regexp.MustCompile(`^([\w\.:]+)\/([\w\.]+)\/(\w+)$`)

	// regexURLPrefix recognises http:// and https:// locations, which
	// decides whether a protoset is downloaded or read from disk.
	regexURLPrefix = regexp.MustCompile(`^https?://`)
)
// NewGRPCTestCaseRunner returns a TestCaseRunner that issues gRPC calls.
// host is the dial target used when a test case's API does not name its own
// host; proto configures where method descriptors are loaded from.
func NewGRPCTestCaseRunner(host string, proto testing.RPCDesc) TestCaseRunner {
	return &gRPCTestCaseRunner{
		UnimplementedRunner: NewDefaultUnimplementedRunner(),
		host:                host,
		proto:               proto,
	}
}
// RunTestCase executes a single gRPC test case.
//
// Steps:
//  1. Render the request templates and run the Before job.
//  2. Dial the target. The host is taken from the API when it contains one;
//     otherwise the runner's default host is used.
//  3. Resolve the method descriptor and invoke the RPC.
//  4. Verify the JSON-encoded responses against testcase.Expect.
//
// The After job runs only when everything before it succeeded, and its error
// becomes the returned error. A report record carrying the final error is
// always submitted to the test reporter.
func (r *gRPCTestCaseRunner) RunTestCase(testcase *testing.TestCase, dataContext any, ctx context.Context) (output any, err error) {
	r.log.Info("start to run: '%s'\n", testcase.Name)
	record := NewReportRecord()
	// Deferred calls run LIFO: this one runs last, so it also records an
	// error produced by the After job registered below.
	defer func(rr *ReportRecord) {
		rr.EndTime = time.Now()
		rr.Error = err
		rr.API = testcase.Request.API
		rr.Method = "gRPC"
		r.testReporter.PutRecord(rr)
	}(record)
	defer func() {
		if err == nil {
			err = runJob(testcase.After, dataContext)
		}
	}()

	contextDir := NewContextKeyBuilder().ParentDir().GetContextValueOrEmpty(ctx)
	if err = testcase.Request.Render(dataContext, contextDir); err != nil {
		return
	}
	if err = runJob(testcase.Before, dataContext); err != nil {
		return
	}

	r.log.Info("start to send request to %s\n", testcase.Request.API)

	// Pick the transport credentials first and dial exactly once. err is
	// assigned, never redeclared, so a failure in either branch reaches the
	// check below. Redeclaring it in the TLS branch would hide a dial error
	// and leave conn nil for the deferred Close.
	var creds credentials.TransportCredentials
	if r.Secure == nil || r.Secure.Insecure {
		creds = insecure.NewCredentials()
	} else {
		creds, err = credentials.NewClientTLSFromFile(r.Secure.CertFile, r.Secure.ServerName)
		if err != nil {
			return nil, err
		}
	}
	conn, err := grpc.Dial(getHost(testcase.Request.API, r.host), grpc.WithTransportCredentials(creds))
	if err != nil {
		return nil, err
	}
	defer conn.Close()

	md, err := getMethodDescriptor(ctx, r, testcase, conn)
	if err != nil {
		if err == protoregistry.NotFound {
			return nil, fmt.Errorf("api %q is not found", testcase.Request.API)
		}
		return nil, err
	}

	respsStr, err := invokeRequest(ctx, md, testcase.Request.Body, conn)
	if err != nil {
		return nil, err
	}
	record.Body = strings.Join(respsStr, ",")
	r.log.Debug("response body: %s\n", record.Body)

	if output, err = verifyResponsePayload(md, testcase.Name, testcase.Expect, respsStr); err != nil {
		return nil, err
	}
	return output, nil
}
// GetResponseRecord returns the response record stored on the runner.
// NOTE(review): nothing in this file assigns r.response, so unless another
// file in the package does, this is always the zero SimpleResponse.
func (r *gRPCTestCaseRunner) GetResponseRecord() SimpleResponse {
	stored := r.response
	return stored
}
// invokeRequest turns the JSON payload into request message(s), calls the
// method over conn and returns every response message encoded as JSON.
//
// Unary methods send exactly one message built from payload. Streaming
// methods (client, server or bidirectional) accept either a single JSON
// object or a JSON array, each element becoming one request.
func invokeRequest(ctx context.Context, md protoreflect.MethodDescriptor, payload string, conn *grpc.ClientConn) (respones []string, err error) {
	if !md.IsStreamingClient() && !md.IsStreamingServer() {
		request, err := getMessagePb(md.Input(), payload)
		if err != nil {
			return nil, err
		}
		reply, err := invokeRPC(ctx, conn, md, request)
		if err != nil {
			return nil, err
		}
		return buildResponses([]*dynamicpb.Message{reply})
	}

	requests, err := getStreamMessagepb(md.Input(), payload)
	if err != nil {
		return nil, err
	}
	replies, err := invokeRPCStream(ctx, conn, md, requests)
	if err != nil {
		return nil, err
	}
	return buildResponses(replies)
}
// getStreamMessagepb decodes messages into dynamic messages of type md.
// A JSON array yields one message per element. Any other input (including
// an empty string) is treated as a single message. The first decode error
// aborts the whole conversion.
func getStreamMessagepb(md protoreflect.MessageDescriptor, messages string) ([]*dynamicpb.Message, error) {
	parsed := gjson.Parse(messages)
	items := []gjson.Result{parsed}
	if parsed.IsArray() {
		items = parsed.Array()
	}

	decoded := make([]*dynamicpb.Message, 0, len(items))
	for _, item := range items {
		msg, err := getMessagePb(md, item.Raw)
		if err != nil {
			return nil, err
		}
		decoded = append(decoded, msg)
	}
	return decoded, nil
}
// getMessagePb creates a dynamic message of type md and fills it from the
// protojson-encoded message. An empty string yields an empty message.
func getMessagePb(md protoreflect.MessageDescriptor, message string) (messagepb *dynamicpb.Message, err error) {
	messagepb = dynamicpb.NewMessage(md)
	if message == "" {
		return messagepb, nil
	}
	if err = protojson.Unmarshal([]byte(message), messagepb); err != nil {
		return nil, err
	}
	return messagepb, nil
}
// buildResponses encodes each response message with protojson, preserving
// order. It always returns a non-nil slice on success and stops at the first
// encoding error.
func buildResponses(resps []*dynamicpb.Message) ([]string, error) {
	encoded := make([]string, 0, len(resps))
	for _, msg := range resps {
		raw, err := protojson.Marshal(msg)
		if err != nil {
			return nil, err
		}
		encoded = append(encoded, string(raw))
	}
	return encoded, nil
}
// getMethodDescriptor resolves testcase.Request.API
// ("<host>/<package.Service>/<Method>") to a method descriptor.
//
// With ServerReflection enabled, the descriptor is fetched from the server
// over conn. Otherwise it comes from ProtoFile, Raw or ProtoSet, and an error
// is returned when none of those is configured. protoregistry.NotFound is
// returned when the name resolves to a placeholder or to a non-method
// descriptor.
func getMethodDescriptor(ctx context.Context, r *gRPCTestCaseRunner, testcase *testing.TestCase, conn *grpc.ClientConn) (protoreflect.MethodDescriptor, error) {
	fullname, err := splitFullQualifiedName(testcase.Request.API)
	if err != nil {
		return nil, err
	}

	// A lookup in the (disabled) r.fdCache descriptor cache would go here.

	var dp protoreflect.Descriptor
	switch {
	case r.proto.ServerReflection:
		dp, err = getByReflect(ctx, r, fullname, conn)
	case r.proto.ProtoFile == "" && r.proto.ProtoSet == "" && r.proto.Raw == "":
		return nil, fmt.Errorf("missing descriptor source")
	default:
		dp, err = getByProto(ctx, r, fullname)
	}
	if err != nil {
		return nil, err
	}

	if dp.IsPlaceholder() {
		return nil, protoregistry.NotFound
	}
	md, ok := dp.(protoreflect.MethodDescriptor)
	if !ok {
		return nil, protoregistry.NotFound
	}
	return md, nil
}
// getByProto resolves fullName by compiling proto sources with protocompile,
// using r.proto.ImportPath plus the standard imports. When a ProtoSet is
// configured, the lookup is delegated to getByProtoSet instead.
//
// When r.proto.Raw is set, its content is written to a temporary file that is
// compiled in place of r.proto.ProtoFile and removed on return. The runner's
// configuration is left unchanged.
func getByProto(ctx context.Context, r *gRPCTestCaseRunner, fullName protoreflect.FullName) (protoreflect.Descriptor, error) {
	if r.proto.ProtoSet != "" {
		return getByProtoSet(ctx, r, fullName)
	}

	compiler := protocompile.Compiler{
		Resolver: protocompile.WithStandardImports(
			&protocompile.SourceResolver{
				ImportPaths: r.proto.ImportPath,
			},
		),
	}

	// Use a local copy so r.proto.ProtoFile is never rewritten to a temp path
	// that no longer exists once this function returns.
	protoFile := r.proto.ProtoFile
	if r.proto.Raw != "" {
		// NOTE(review): the temp path is absolute; if SourceResolver joins it
		// onto each ImportPath entry, raw content combined with a non-empty
		// ImportPath may fail to resolve — confirm.
		f, err := os.CreateTemp(os.TempDir(), "proto")
		if err != nil {
			return nil, fmt.Errorf("failed to create temp file when saving proto content: %w", err)
		}
		defer os.Remove(f.Name())

		_, err = f.WriteString(r.proto.Raw)
		// Always close the handle so the descriptor is not leaked; report
		// the write error first, otherwise any close error.
		if closeErr := f.Close(); err == nil {
			err = closeErr
		}
		if err != nil {
			return nil, fmt.Errorf("failed to write proto content to file %q: %w", f.Name(), err)
		}
		protoFile = f.Name()
	}

	linker, err := compiler.Compile(ctx, protoFile)
	if err != nil {
		return nil, err
	}
	// r.fdCache.Store(fullName.Parent(), dp.ParentFile())
	return linker.AsResolver().FindDescriptorByName(fullName)
}
// getByProtoSet resolves fullName from a serialized FileDescriptorSet
// ("protoset"). r.proto.ProtoSet is either an http(s) URL or a local file
// path. A URL is downloaded with ctx, so the download is cancelled together
// with the caller's context, and any non-2xx response is rejected.
func getByProtoSet(ctx context.Context, r *gRPCTestCaseRunner, fullName protoreflect.FullName) (protoreflect.Descriptor, error) {
	var decs []byte
	var err error
	if regexURLPrefix.MatchString(r.proto.ProtoSet) {
		var req *http.Request
		if req, err = http.NewRequestWithContext(ctx, http.MethodGet, r.proto.ProtoSet, nil); err != nil {
			return nil, err
		}
		var resp *http.Response
		if resp, err = http.DefaultClient.Do(req); err != nil {
			return nil, err
		}
		defer resp.Body.Close()
		// An error page is not a protoset: fail with the HTTP status rather
		// than a confusing protobuf unmarshal error.
		if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
			return nil, fmt.Errorf("failed to download protoset %q: %s", r.proto.ProtoSet, resp.Status)
		}
		decs, err = io.ReadAll(resp.Body)
	} else {
		decs, err = os.ReadFile(r.proto.ProtoSet)
	}
	if err != nil {
		return nil, err
	}

	fds := &descriptorpb.FileDescriptorSet{}
	if err = proto.Unmarshal(decs, fds); err != nil {
		return nil, err
	}
	prfs, err := protodesc.NewFiles(fds)
	if err != nil {
		return nil, err
	}
	// r.fdCache.Store(fullName.Parent(), dp.ParentFile())
	return prfs.FindDescriptorByName(fullName)
}
// getByReflect resolves fullName using the gRPC server reflection service
// (v1) on conn. It asks for the file containing the symbol, builds every
// returned file descriptor as one set so that imports resolve against each
// other, and looks up fullName in that set. protoregistry.NotFound is
// returned when the symbol is absent from the returned files.
func getByReflect(ctx context.Context, r *gRPCTestCaseRunner, fullName protoreflect.FullName, conn *grpc.ClientConn) (md protoreflect.Descriptor, err error) {
	// Cancel on return so the reflection stream does not outlive this lookup.
	ctx, cancel := context.WithCancel(ctx)
	defer cancel()

	reflectconn := grpc_reflection_v1.NewServerReflectionClient(conn)
	cli, err := reflectconn.ServerReflectionInfo(ctx)
	if err != nil {
		return nil, err
	}

	req := &grpc_reflection_v1.ServerReflectionRequest{
		Host: "",
		MessageRequest: &grpc_reflection_v1.ServerReflectionRequest_FileContainingSymbol{
			FileContainingSymbol: string(fullName),
		},
	}
	if err = cli.Send(req); err != nil {
		return nil, err
	}
	resp, err := cli.Recv()
	if err != nil {
		return nil, err
	}
	_ = cli.CloseSend()

	if errResp := resp.GetErrorResponse(); errResp != nil {
		// The server's message is data, not a format string.
		return nil, fmt.Errorf("%s", errResp.GetErrorMessage())
	}

	// Getters are nil-safe: a response without FileDescriptorResponse yields
	// no files (and hence NotFound) instead of a nil-pointer panic. The files
	// may depend on each other in any order; NewFiles links them together,
	// whereas building each one alone fails on its first import.
	fds := &descriptorpb.FileDescriptorSet{}
	for _, fdb := range resp.GetFileDescriptorResponse().GetFileDescriptorProto() {
		fdp := &descriptorpb.FileDescriptorProto{}
		if err := proto.Unmarshal(fdb, fdp); err != nil {
			return nil, err
		}
		fds.File = append(fds.File, fdp)
	}
	files, err := protodesc.NewFiles(fds)
	if err != nil {
		return nil, err
	}
	// r.fdCache.Store(fullName.Parent(), <file of the descriptor>)
	return files.FindDescriptorByName(fullName)
}
// getMdFromFd looks up a method inside fd. The service is matched by the
// short name of fullname's parent, and the method by fullname's last
// component. A descriptive error is returned when either one is missing.
func getMdFromFd(fd protoreflect.FileDescriptor, fullname protoreflect.FullName) (md protoreflect.MethodDescriptor, err error) {
	serviceName := fullname.Parent().Name()
	methodName := fullname.Name()

	service := fd.Services().ByName(serviceName)
	if service == nil {
		return nil, fmt.Errorf("grpc service %q is not found in proto %q", serviceName, fd.Name())
	}
	if md = service.Methods().ByName(methodName); md == nil {
		return nil, fmt.Errorf("method %q is not found in service %q", methodName, serviceName)
	}
	return md, nil
}
// splitFullQualifiedName converts "<host>/<package.Service>/<Method>" into
// the protobuf full name "package.Service.Method". The host part is dropped.
// It returns an error when api does not have that shape or the result is not
// a valid full name.
func splitFullQualifiedName(api string) (protoreflect.FullName, error) {
	if parts := regexFullQualifiedName.FindStringSubmatch(api); len(parts) != 0 {
		if fn := protoreflect.FullName(parts[2] + "." + parts[3]); fn.IsValid() {
			return fn, nil
		}
	}
	return "", fmt.Errorf("%q is not a valid gRPC api name", api)
}
// getHost returns the host segment of a "<host>/<Service>/<Method>" API
// string, or fallback when api does not have that shape.
func getHost(api, fallback string) (host string) {
	host = fallback
	if parts := regexFullQualifiedName.FindStringSubmatch(api); parts != nil {
		host = parts[1]
	}
	return host
}
// getMethodName builds the method path "/<package.Service>/<Method>" that
// this file passes to conn.Invoke and conn.NewStream.
func getMethodName(md protoreflect.MethodDescriptor) string {
	service := string(md.Parent().FullName())
	method := string(md.Name())
	return "/" + service + "/" + method
}
// invokeRPC performs a unary call of method over conn with request and
// decodes the reply into a new dynamic message of the method's output type.
func invokeRPC(ctx context.Context, conn grpc.ClientConnInterface, method protoreflect.MethodDescriptor, request *dynamicpb.Message) (resp *dynamicpb.Message, err error) {
	reply := dynamicpb.NewMessage(method.Output())
	if err = conn.Invoke(ctx, getMethodName(method), request, reply); err != nil {
		return nil, err
	}
	return reply, nil
}
// invokeRPCStream handles client-streaming, server-streaming and
// bidirectional RPCs with a single code path:
//  1. Send every request.
//  2. Half-close the send side.
//  3. Collect responses until the server ends the stream.
//
// Responses are returned in arrival order. Context cancellation is checked
// between messages, and the first error is returned.
func invokeRPCStream(ctx context.Context, conn grpc.ClientConnInterface, method protoreflect.MethodDescriptor, requests []*dynamicpb.Message) (resps []*dynamicpb.Message, err error) {
	sd := &grpc.StreamDesc{
		StreamName:    string(method.Name()),
		ServerStreams: method.IsStreamingServer(),
		ClientStreams: method.IsStreamingClient(),
	}
	s, err := conn.NewStream(ctx, sd, getMethodName(method))
	if err != nil {
		return nil, err
	}

	for _, req := range requests {
		if err = ctx.Err(); err != nil {
			return nil, err
		}
		if err = s.SendMsg(req); err != nil {
			if err == io.EOF {
				// The server already terminated the stream. SendMsg reports
				// that as io.EOF, and the actual status is delivered by
				// RecvMsg, so stop sending and let the receive loop surface
				// it.
				break
			}
			return nil, err
		}
	}
	if err = s.CloseSend(); err != nil {
		return nil, err
	}

	for {
		if err = ctx.Err(); err != nil {
			return nil, err
		}
		resp := dynamicpb.NewMessage(method.Output())
		if err = s.RecvMsg(resp); err != nil {
			if err == io.EOF {
				// Clean end of stream.
				return resps, nil
			}
			return nil, err
		}
		resps = append(resps, resp)
	}
}
// verifyResponsePayload checks the JSON-encoded responses against expect and,
// on success, returns {"data": [<decoded response object>, ...]}. RunTestCase
// returns this value as the test case output.
//
// Two checks run in order:
//  1. payloadFieldsVerify compares the expected body with the responses.
//  2. Verify evaluates the expectation against the {"data": ...} map.
func verifyResponsePayload(md protoreflect.MethodDescriptor, caseName string, expect testing.Response, jsonPayload []string) (output any, err error) {
	data := make([]map[string]any, len(jsonPayload))
	for i, payload := range jsonPayload {
		m := map[string]any{}
		// In this file the payloads come from buildResponses
		// (protojson.Marshal), so they should be JSON objects. A decode
		// failure leaves the entry empty rather than aborting verification.
		_ = json.Unmarshal([]byte(payload), &m)
		data[i] = m
	}
	mapOutput := map[string]any{"data": data}

	if err = payloadFieldsVerify(md, caseName, expect, jsonPayload); err != nil {
		return nil, err
	}
	if err = Verify(expect, mapOutput); err != nil {
		return nil, err
	}
	// Hand the decoded responses back; the named result was never assigned,
	// so callers always received nil.
	return mapOutput, nil
}
// payloadFieldsVerify compares expect.Body with the JSON response payloads.
//
// An empty expected body skips the check. The expected body is normalised
// through the method's output type (see parseExpect). Then:
//   - An expected JSON array is compared element-wise with the list of
//     responses.
//   - An expected JSON object is compared against every response, and all
//     mismatch messages are reported together.
func payloadFieldsVerify(md protoreflect.MethodDescriptor, caseName string, expect testing.Response, jsonPayload []string) error {
	if expect.Body == "" {
		return nil
	}
	if !gjson.Valid(expect.Body) {
		return fmt.Errorf("case %q: expect body is not a valid JSON", caseName)
	}
	exp, err := parseExpect(md, expect)
	if err != nil {
		return err
	}

	gjsonPayload := make([]gjson.Result, len(jsonPayload))
	for i, payload := range jsonPayload {
		gjsonPayload[i] = gjson.Parse(payload)
	}

	switch {
	case exp.IsArray():
		return compare.Array(caseName, exp.Array(), gjsonPayload)
	case exp.IsObject():
		var msg strings.Builder
		for i, payload := range gjsonPayload {
			if err := compare.Object(fmt.Sprintf("%v[%v]", caseName, i),
				exp.Map(), payload.Map()); err != nil {
				msg.WriteString(err.Error())
			}
		}
		if msg.Len() > 0 {
			// The message embeds response data; passing it as the format
			// string would mangle any '%' it contains.
			return fmt.Errorf("%s", msg.String())
		}
		return nil
	default:
		return fmt.Errorf("case %q: unknown expect content", caseName)
	}
}
// parseExpect re-encodes the expected body through the method's output
// message type. Each expected message is decoded with protojson into a
// dynamic message and marshaled again, presumably so the expectation uses
// the same encoding as the responses.
//
// A body starting with '[' (after trimming whitespace) is treated as a JSON
// array of messages and re-encoded as a JSON array. Anything else is treated
// as a single message.
func parseExpect(md protoreflect.MethodDescriptor, expect testing.Response) (exps gjson.Result, err error) {
	b := strings.TrimSpace(expect.Body)
	if strings.HasPrefix(b, "[") {
		msgpbs, err := getStreamMessagepb(md.Output(), b)
		if err != nil {
			return gjson.Result{}, err
		}
		// Build "[m0,m1,...]" with commas only between elements; an empty
		// list yields "[]".
		msgb := []byte{'['}
		for i, msgpb := range msgpbs {
			if i > 0 {
				msgb = append(msgb, ',')
			}
			msg, err := protojson.Marshal(msgpb)
			if err != nil {
				return gjson.Result{}, err
			}
			msgb = append(msgb, msg...)
		}
		msgb = append(msgb, ']')
		return gjson.ParseBytes(msgb), nil
	}

	msgpb, err := getMessagePb(md.Output(), expect.Body)
	if err != nil {
		return gjson.Result{}, err
	}
	msgb, err := protojson.Marshal(msgpb)
	if err != nil {
		return gjson.Result{}, err
	}
	return gjson.ParseBytes(msgb), nil
}