feat: support to set upstream proxy address (#86)

This commit is contained in:
Rick 2023-06-08 11:14:11 +08:00 committed by GitHub
parent 8da089067c
commit ec32ff386c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 118 additions and 36 deletions

View File

@ -10,3 +10,10 @@ It will start a HTTP proxy server, and set the server address to your browser pr
`atest-collector` will record all HTTP requests which has prefix `/answer/api/v1`, and `atest-collector` will record all HTTP requests which has prefix `/answer/api/v1`, and
save it to file `sample.yaml` once you close the server. save it to file `sample.yaml` once you close the server.
## Features
* Basic authorization
* Upstream proxy
* URL path filter
* Support save response body or not

View File

@ -1,24 +1,33 @@
package cmd package cmd
import ( import (
"bytes"
"context" "context"
"fmt" "fmt"
"io"
"net/http" "net/http"
"net/url"
"os" "os"
"os/signal" "os/signal"
"strings" "strings"
"syscall" "syscall"
"github.com/elazarl/goproxy" "github.com/elazarl/goproxy"
"github.com/elazarl/goproxy/ext/auth"
"github.com/linuxsuren/api-testing/extensions/collector/pkg" "github.com/linuxsuren/api-testing/extensions/collector/pkg"
"github.com/linuxsuren/api-testing/extensions/collector/pkg/filter" "github.com/linuxsuren/api-testing/extensions/collector/pkg/filter"
"github.com/spf13/cobra" "github.com/spf13/cobra"
) )
type option struct { type option struct {
port int port int
filterPath string filterPath []string
output string saveResponseBody bool
output string
upstreamProxy string
verbose bool
username string
password string
} }
// NewRootCmd creates the root command // NewRootCmd creates the root command
@ -31,8 +40,13 @@ func NewRootCmd() (c *cobra.Command) {
} }
flags := c.Flags() flags := c.Flags()
flags.IntVarP(&opt.port, "port", "p", 8080, "The port for the proxy") flags.IntVarP(&opt.port, "port", "p", 8080, "The port for the proxy")
flags.StringVarP(&opt.filterPath, "filter-path", "", "", "The path prefix for filtering") flags.StringSliceVarP(&opt.filterPath, "filter-path", "", []string{}, "The path prefix for filtering")
flags.BoolVarP(&opt.saveResponseBody, "save-response-body", "", false, "Save the response body")
flags.StringVarP(&opt.output, "output", "o", "sample.yaml", "The output file") flags.StringVarP(&opt.output, "output", "o", "sample.yaml", "The output file")
flags.StringVarP(&opt.upstreamProxy, "upstream-proxy", "", "", "The upstream proxy")
flags.StringVarP(&opt.username, "username", "", "", "The username for basic auth")
flags.StringVarP(&opt.password, "password", "", "", "The password for basic auth")
flags.BoolVarP(&opt.verbose, "verbose", "", false, "Verbose mode")
_ = cobra.MarkFlagRequired(flags, "filter-path") _ = cobra.MarkFlagRequired(flags, "filter-path")
return return
@ -41,6 +55,7 @@ func NewRootCmd() (c *cobra.Command) {
type responseFilter struct { type responseFilter struct {
urlFilter *filter.URLPathFilter urlFilter *filter.URLPathFilter
collects *pkg.Collects collects *pkg.Collects
ctx context.Context
} }
func (f *responseFilter) filter(resp *http.Response, ctx *goproxy.ProxyCtx) *http.Response { func (f *responseFilter) filter(resp *http.Response, ctx *goproxy.ProxyCtx) *http.Response {
@ -51,7 +66,16 @@ func (f *responseFilter) filter(resp *http.Response, ctx *goproxy.ProxyCtx) *htt
req := resp.Request req := resp.Request
if f.urlFilter.Filter(req.URL) { if f.urlFilter.Filter(req.URL) {
f.collects.Add(req.Clone(context.TODO())) simpleResp := &pkg.SimpleResponse{StatusCode: resp.StatusCode}
if resp.Body != nil {
buf := new(bytes.Buffer)
io.Copy(buf, resp.Body)
simpleResp.Body = buf.String()
resp.Body = io.NopCloser(buf)
}
f.collects.Add(req.Clone(f.ctx), simpleResp)
} }
return resp return resp
} }
@ -59,13 +83,25 @@ func (f *responseFilter) filter(resp *http.Response, ctx *goproxy.ProxyCtx) *htt
func (o *option) runE(cmd *cobra.Command, args []string) (err error) { func (o *option) runE(cmd *cobra.Command, args []string) (err error) {
urlFilter := &filter.URLPathFilter{PathPrefix: o.filterPath} urlFilter := &filter.URLPathFilter{PathPrefix: o.filterPath}
collects := pkg.NewCollects() collects := pkg.NewCollects()
responseFilter := &responseFilter{urlFilter: urlFilter, collects: collects} responseFilter := &responseFilter{urlFilter: urlFilter, collects: collects, ctx: cmd.Context()}
proxy := goproxy.NewProxyHttpServer() proxy := goproxy.NewProxyHttpServer()
proxy.Verbose = true proxy.Verbose = o.verbose
if o.upstreamProxy != "" {
proxy.Tr.Proxy = func(r *http.Request) (*url.URL, error) {
return url.Parse(o.upstreamProxy)
}
proxy.ConnectDial = proxy.NewConnectDialToProxy(o.upstreamProxy)
cmd.Println("Using upstream proxy", o.upstreamProxy)
}
if o.username != "" && o.password != "" {
auth.ProxyBasic(proxy, "my_realm", func(user, pwd string) bool {
return user == o.username && o.password == pwd
})
}
proxy.OnResponse().DoFunc(responseFilter.filter) proxy.OnResponse().DoFunc(responseFilter.filter)
exporter := pkg.NewSampleExporter() exporter := pkg.NewSampleExporter(o.saveResponseBody)
collects.AddEvent(exporter.Add) collects.AddEvent(exporter.Add)
srv := &http.Server{ srv := &http.Server{

View File

@ -1,6 +1,9 @@
package cmd package cmd
import ( import (
"bytes"
"context"
"io"
"net/http" "net/http"
"net/url" "net/url"
"testing" "testing"
@ -17,19 +20,26 @@ func TestNewRootCmd(t *testing.T) {
} }
func TestResponseFilter(t *testing.T) { func TestResponseFilter(t *testing.T) {
targetURL, err := url.Parse("http://foo.com/api/v1")
assert.NoError(t, err)
resp := &http.Response{ resp := &http.Response{
Header: http.Header{ Header: http.Header{
"Content-Type": []string{"application/json; charset=utf-8"}, "Content-Type": []string{"application/json; charset=utf-8"},
}, },
Request: &http.Request{ Request: &http.Request{
URL: &url.URL{}, URL: targetURL,
}, },
Body: io.NopCloser(bytes.NewBuffer([]byte("hello"))),
} }
emptyResp := &http.Response{} emptyResp := &http.Response{}
filter := &responseFilter{ filter := &responseFilter{
urlFilter: &filter.URLPathFilter{}, urlFilter: &filter.URLPathFilter{
collects: pkg.NewCollects(), PathPrefix: []string{"/api/v1"},
},
collects: pkg.NewCollects(),
ctx: context.Background(),
} }
filter.filter(emptyResp, nil) filter.filter(emptyResp, nil)
filter.filter(resp, nil) filter.filter(resp, nil)

View File

@ -11,33 +11,46 @@ type Collects struct {
once sync.Once once sync.Once
signal chan string signal chan string
stopSignal chan struct{} stopSignal chan struct{}
keys map[string]*http.Request keys map[string]*RequestAndResponse
requests []*http.Request requests []*http.Request
events []EventHandle events []EventHandle
} }
type SimpleResponse struct {
StatusCode int
Body string
}
type RequestAndResponse struct {
Request *http.Request
Response *SimpleResponse
}
// NewCollects creates an instance of Collector // NewCollects creates an instance of Collector
func NewCollects() *Collects { func NewCollects() *Collects {
return &Collects{ return &Collects{
once: sync.Once{}, once: sync.Once{},
signal: make(chan string, 5), signal: make(chan string, 5),
stopSignal: make(chan struct{}, 1), stopSignal: make(chan struct{}, 1),
keys: make(map[string]*http.Request), keys: make(map[string]*RequestAndResponse),
} }
} }
// Add adds a HTTP request // Add adds a HTTP request
func (c *Collects) Add(req *http.Request) { func (c *Collects) Add(req *http.Request, resp *SimpleResponse) {
key := fmt.Sprintf("%s-%s", req.Method, req.URL.String()) key := fmt.Sprintf("%s-%s", req.Method, req.URL.String())
if _, ok := c.keys[key]; !ok { if _, ok := c.keys[key]; !ok {
c.keys[key] = req c.keys[key] = &RequestAndResponse{
Request: req,
Response: resp,
}
c.requests = append(c.requests, req) c.requests = append(c.requests, req)
c.signal <- key c.signal <- key
} }
} }
// EventHandle is the collect event handle // EventHandle is the collect event handle
type EventHandle func(r *http.Request) type EventHandle func(r *RequestAndResponse)
// AddEvent adds new event handle // AddEvent adds new event handle
func (c *Collects) AddEvent(e EventHandle) { func (c *Collects) AddEvent(e EventHandle) {
@ -60,7 +73,6 @@ func (c *Collects) handleEvents() {
case key := <-c.signal: case key := <-c.signal:
fmt.Println("receive signal", key) fmt.Println("receive signal", key)
for _, e := range c.events { for _, e := range c.events {
fmt.Println("handle event", key, e)
e(c.keys[key]) e(c.keys[key])
} }
case <-c.stopSignal: case <-c.stopSignal:

View File

@ -21,11 +21,12 @@ func TestCollector(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
collects := pkg.NewCollects() collects := pkg.NewCollects()
collects.AddEvent(func(r *http.Request) { collects.AddEvent(func(reqAndResp *pkg.RequestAndResponse) {
r := reqAndResp.Request
assert.Equal(t, tt.Request, r) assert.Equal(t, tt.Request, r)
}) })
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
collects.Add(tt.Request) collects.Add(tt.Request, nil)
} }
collects.Stop() collects.Stop()
}) })

View File

@ -3,7 +3,6 @@ package pkg
import ( import (
"fmt" "fmt"
"io" "io"
"net/http"
"strings" "strings"
atestpkg "github.com/linuxsuren/api-testing/pkg/testing" atestpkg "github.com/linuxsuren/api-testing/pkg/testing"
@ -12,20 +11,23 @@ import (
// SampleExporter is a sample exporter // SampleExporter is a sample exporter
type SampleExporter struct { type SampleExporter struct {
TestSuite atestpkg.TestSuite TestSuite atestpkg.TestSuite
saveResponseBody bool
} }
// NewSampleExporter creates a new exporter // NewSampleExporter creates a new exporter
func NewSampleExporter() *SampleExporter { func NewSampleExporter(saveResponseBody bool) *SampleExporter {
return &SampleExporter{ return &SampleExporter{
TestSuite: atestpkg.TestSuite{ TestSuite: atestpkg.TestSuite{
Name: "sample", Name: "sample",
}, },
saveResponseBody: saveResponseBody,
} }
} }
// Add adds a request to the exporter // Add adds a request to the exporter
func (e *SampleExporter) Add(r *http.Request) { func (e *SampleExporter) Add(reqAndResp *RequestAndResponse) {
r, resp := reqAndResp.Request, reqAndResp.Response
fmt.Println("receive", r.URL.Path) fmt.Println("receive", r.URL.Path)
req := atestpkg.Request{ req := atestpkg.Request{
@ -42,9 +44,13 @@ func (e *SampleExporter) Add(r *http.Request) {
testCase := atestpkg.TestCase{ testCase := atestpkg.TestCase{
Request: req, Request: req,
Expect: atestpkg.Response{ }
StatusCode: http.StatusOK,
}, if resp != nil {
testCase.Expect.StatusCode = resp.StatusCode
if e.saveResponseBody && resp.Body != "" {
testCase.Expect.Body = resp.Body
}
} }
specs := strings.Split(r.URL.Path, "/") specs := strings.Split(r.URL.Path, "/")

View File

@ -12,15 +12,21 @@ import (
) )
func TestSampleExporter(t *testing.T) { func TestSampleExporter(t *testing.T) {
exporter := pkg.NewSampleExporter() exporter := pkg.NewSampleExporter(true)
assert.Equal(t, "sample", exporter.TestSuite.Name) assert.Equal(t, "sample", exporter.TestSuite.Name)
request, err := newRequest() request, err := newRequest()
assert.NoError(t, err) assert.NoError(t, err)
exporter.Add(request) exporter.Add(&pkg.RequestAndResponse{Request: request})
request, err = newRequest() request, err = newRequest()
exporter.Add(request) exporter.Add(&pkg.RequestAndResponse{
Request: request,
Response: &pkg.SimpleResponse{
Body: "hello",
StatusCode: http.StatusOK,
},
})
var result string var result string
result, err = exporter.Export() result, err = exporter.Export()

View File

@ -1,7 +1,6 @@
package filter package filter
import ( import (
"fmt"
"net/url" "net/url"
"strings" "strings"
) )
@ -13,11 +12,15 @@ type URLFilter interface {
// URLPathFilter filters the URL with path // URLPathFilter filters the URL with path
type URLPathFilter struct { type URLPathFilter struct {
PathPrefix string PathPrefix []string
} }
// Filter implements the URLFilter // Filter implements the URLFilter
func (f *URLPathFilter) Filter(targetURL *url.URL) bool { func (f *URLPathFilter) Filter(targetURL *url.URL) bool {
fmt.Println(targetURL.Path, f.PathPrefix) for _, prefix := range f.PathPrefix {
return strings.HasPrefix(targetURL.Path, f.PathPrefix) if strings.HasPrefix(targetURL.Path, prefix) {
return true
}
}
return false
} }

View File

@ -9,6 +9,8 @@ import (
) )
func TestURLPathFilter(t *testing.T) { func TestURLPathFilter(t *testing.T) {
urlFilter := &filter.URLPathFilter{PathPrefix: "/api"} urlFilter := &filter.URLPathFilter{PathPrefix: []string{"/api/v1", "/api/v2"}}
assert.True(t, urlFilter.Filter(&url.URL{Path: "/api/v1"})) assert.True(t, urlFilter.Filter(&url.URL{Path: "/api/v1"}))
assert.True(t, urlFilter.Filter(&url.URL{Path: "/api/v2"}))
assert.False(t, urlFilter.Filter(&url.URL{Path: "/api/v3"}))
} }

View File

@ -10,8 +10,6 @@ items:
Authorization: Bearer token Authorization: Bearer token
Content-Type: application/json Content-Type: application/json
body: hello body: hello
expect:
statusCode: 200
- name: v1-1 - name: v1-1
request: request:
api: http://foo/api/v1 api: http://foo/api/v1
@ -22,3 +20,4 @@ items:
body: hello body: hello
expect: expect:
statusCode: 200 statusCode: 200
body: hello