Bläddra i källkod

Add base tests for proxy

tags/v2.0.0-rc1
9seconds 5 år sedan
förälder
incheckning
2cdff017e2
2 ändrade filer med 212 tillägg och 12 borttagningar
  1. 28
    12
      mtglib/proxy.go
  2. 184
    0
      mtglib/proxy_test.go

+ 28
- 12
mtglib/proxy.go Visa fil

@@ -42,6 +42,9 @@ func (p *Proxy) DomainFrontingAddress() string {
42 42
 }
43 43
 
44 44
 func (p *Proxy) ServeConn(conn net.Conn) {
45
+	p.streamWaitGroup.Add(1)
46
+	defer p.streamWaitGroup.Done()
47
+
45 48
 	ctx := newStreamContext(p.ctx, p.logger, conn)
46 49
 	defer ctx.Close()
47 50
 
@@ -91,20 +94,24 @@ func (p *Proxy) ServeConn(conn net.Conn) {
91 94
 }
92 95
 
93 96
 func (p *Proxy) Serve(listener net.Listener) error {
97
+	p.streamWaitGroup.Add(1)
98
+	defer p.streamWaitGroup.Done()
99
+
94 100
 	for {
95 101
 		conn, err := listener.Accept()
96 102
 		if err != nil {
97 103
 			return fmt.Errorf("cannot accept a new connection: %w", err)
98 104
 		}
99 105
 
100
-		if addr := conn.RemoteAddr().(*net.TCPAddr).IP; p.ipBlocklist.Contains(addr) {
106
+		ipAddr := conn.RemoteAddr().(*net.TCPAddr).IP
107
+		logger := p.logger.BindStr("ip", ipAddr.String())
108
+
109
+		if p.ipBlocklist.Contains(ipAddr) {
101 110
 			conn.Close()
102
-			p.logger.
103
-				BindStr("ip", conn.RemoteAddr().(*net.TCPAddr).IP.String()).
104
-				Info("ip was blacklisted")
111
+			logger.Info("ip was blacklisted")
105 112
 			p.eventStream.Send(p.ctx, EventIPBlocklisted{
106 113
 				CreatedAt: time.Now(),
107
-				RemoteIP:  addr,
114
+				RemoteIP:  ipAddr,
108 115
 			})
109 116
 
110 117
 			continue
@@ -117,13 +124,17 @@ func (p *Proxy) Serve(listener net.Listener) error {
117 124
 		case errors.Is(err, ants.ErrPoolClosed):
118 125
 			return nil
119 126
 		case errors.Is(err, ants.ErrPoolOverload):
120
-			p.logger.
121
-				BindStr("ip", conn.RemoteAddr().(*net.TCPAddr).IP.String()).
122
-				Info("connection was concurrency limited")
127
+			logger.Info("connection was concurrency limited")
123 128
 			p.eventStream.Send(p.ctx, EventConcurrencyLimited{
124 129
 				CreatedAt: time.Now(),
125 130
 			})
126 131
 		}
132
+
133
+		select {
134
+		case <-p.ctx.Done():
135
+			return p.ctx.Err()
136
+		default:
137
+		}
127 138
 	}
128 139
 }
129 140
 
@@ -292,9 +303,9 @@ func NewProxy(opts ProxyOpts) (*Proxy, error) { // nolint: cyclop, funlen
292 303
 		return nil, ErrSecretInvalid
293 304
 	}
294 305
 
295
-	tg, err := telegram.New(opts.Network, opts.PreferIP)
296
-	if err != nil {
297
-		return nil, fmt.Errorf("cannot build telegram dialer: %w", err)
306
+	preferIP := opts.PreferIP
307
+	if preferIP == "" {
308
+		preferIP = DefaultPreferIP
298 309
 	}
299 310
 
300 311
 	concurrency := opts.Concurrency
@@ -317,6 +328,11 @@ func NewProxy(opts ProxyOpts) (*Proxy, error) { // nolint: cyclop, funlen
317 328
 		domainFrontingPort = DefaultDomainFrontingPort
318 329
 	}
319 330
 
331
+	tg, err := telegram.New(opts.Network, preferIP)
332
+	if err != nil {
333
+		return nil, fmt.Errorf("cannot build telegram dialer: %w", err)
334
+	}
335
+
320 336
 	ctx, cancel := context.WithCancel(context.Background())
321 337
 	proxy := &Proxy{
322 338
 		ctx:                ctx,
@@ -340,7 +356,7 @@ func NewProxy(opts ProxyOpts) (*Proxy, error) { // nolint: cyclop, funlen
340 356
 		ants.WithLogger(opts.Logger.Named("ants")),
341 357
 		ants.WithNonblocking(true))
342 358
 	if err != nil {
343
-		return nil, fmt.Errorf("cannot initialize a pool: %w", err)
359
+		panic(err)
344 360
 	}
345 361
 
346 362
 	proxy.workerPool = pool

+ 184
- 0
mtglib/proxy_test.go Visa fil

@@ -0,0 +1,184 @@
1
+package mtglib_test
2
+
3
+import (
4
+	"crypto/tls"
5
+	"encoding/json"
6
+	"fmt"
7
+	"io"
8
+	"net"
9
+	"net/http"
10
+	"testing"
11
+	"time"
12
+
13
+	"github.com/9seconds/mtg/v2/antireplay"
14
+	"github.com/9seconds/mtg/v2/events"
15
+	"github.com/9seconds/mtg/v2/ipblocklist"
16
+	"github.com/9seconds/mtg/v2/logger"
17
+	"github.com/9seconds/mtg/v2/mtglib"
18
+	"github.com/9seconds/mtg/v2/network"
19
+	"github.com/9seconds/mtg/v2/timeattack"
20
+	"github.com/stretchr/testify/suite"
21
+)
22
+
23
+type ProxyTestSuite struct {
24
+	suite.Suite
25
+
26
+	opts     *mtglib.ProxyOpts
27
+	p        *mtglib.Proxy
28
+	listener net.Listener
29
+}
30
+
31
+func (suite *ProxyTestSuite) ProxyAddress() string {
32
+	_, port, _ := net.SplitHostPort(suite.listener.Addr().String())
33
+
34
+	return net.JoinHostPort("127.0.0.1", port)
35
+}
36
+
37
+func (suite *ProxyTestSuite) ProxySecret() string {
38
+	return suite.opts.Secret.Hex()
39
+}
40
+
41
+func (suite *ProxyTestSuite) SetupSuite() {
42
+	dialer, err := network.NewDefaultDialer(0, 0)
43
+	suite.NoError(err)
44
+
45
+	ntw, err := network.NewNetwork(dialer, "mtgtest", "1.1.1.1", 0)
46
+	suite.NoError(err)
47
+
48
+	suite.opts = &mtglib.ProxyOpts{
49
+		Secret:             mtglib.GenerateSecret("httpbin.org"),
50
+		Network:            ntw,
51
+		AntiReplayCache:    antireplay.NewNoop(),
52
+		TimeAttackDetector: timeattack.NewNoop(),
53
+		IPBlocklist:        ipblocklist.NewNoop(),
54
+		EventStream:        events.NewNoopStream(),
55
+		Logger:             logger.NewNoopLogger(),
56
+	}
57
+
58
+	proxy, err := mtglib.NewProxy(*suite.opts)
59
+	suite.NoError(err)
60
+
61
+	suite.p = proxy
62
+
63
+	listener, err := net.Listen("tcp", ":0")
64
+	suite.NoError(err)
65
+
66
+	suite.listener = listener
67
+
68
+	go suite.p.Serve(suite.listener) // nolint: errcheck
69
+}
70
+
71
+func (suite *ProxyTestSuite) TearDownSuite() {
72
+	if suite.listener != nil {
73
+		suite.listener.Close()
74
+	}
75
+
76
+	if suite.p != nil {
77
+		suite.p.Shutdown()
78
+	}
79
+}
80
+
81
+func (suite *ProxyTestSuite) TestCannotInitNoSecret() {
82
+	opts := *suite.opts
83
+	opts.Secret = mtglib.Secret{}
84
+
85
+	_, err := mtglib.NewProxy(opts)
86
+	suite.Error(err)
87
+}
88
+
89
+func (suite *ProxyTestSuite) TestCannotInitNoNetwork() {
90
+	opts := *suite.opts
91
+	opts.Network = nil
92
+
93
+	_, err := mtglib.NewProxy(opts)
94
+	suite.Error(err)
95
+}
96
+
97
+func (suite *ProxyTestSuite) TestCannotInitNoAntiReplayCache() {
98
+	opts := *suite.opts
99
+	opts.AntiReplayCache = nil
100
+
101
+	_, err := mtglib.NewProxy(opts)
102
+	suite.Error(err)
103
+}
104
+
105
+func (suite *ProxyTestSuite) TestCannotInitNoIPBlocklist() {
106
+	opts := *suite.opts
107
+	opts.IPBlocklist = nil
108
+
109
+	_, err := mtglib.NewProxy(opts)
110
+	suite.Error(err)
111
+}
112
+
113
+func (suite *ProxyTestSuite) TestCannotInitNoEventStream() {
114
+	opts := *suite.opts
115
+	opts.EventStream = nil
116
+
117
+	_, err := mtglib.NewProxy(opts)
118
+	suite.Error(err)
119
+}
120
+
121
+func (suite *ProxyTestSuite) TestCannotInitNoTimeAttackDetector() {
122
+	opts := *suite.opts
123
+	opts.TimeAttackDetector = nil
124
+
125
+	_, err := mtglib.NewProxy(opts)
126
+	suite.Error(err)
127
+}
128
+
129
+func (suite *ProxyTestSuite) TestCannotInitNoLogger() {
130
+	opts := *suite.opts
131
+	opts.Logger = nil
132
+
133
+	_, err := mtglib.NewProxy(opts)
134
+	suite.Error(err)
135
+}
136
+
137
+func (suite *ProxyTestSuite) TestCannotInitIncorrectPreferIP() {
138
+	opts := *suite.opts
139
+	opts.PreferIP = "xxx"
140
+
141
+	_, err := mtglib.NewProxy(opts)
142
+	suite.Error(err)
143
+}
144
+
145
+func (suite *ProxyTestSuite) TestDomainFrontingAddress() {
146
+	suite.Equal("httpbin.org:443", suite.p.DomainFrontingAddress())
147
+}
148
+
149
+func (suite *ProxyTestSuite) TestHTTPSRequest() {
150
+	client := &http.Client{
151
+		Transport: &http.Transport{
152
+			TLSClientConfig: &tls.Config{
153
+				InsecureSkipVerify: true,
154
+			},
155
+		},
156
+		Timeout: 5 * time.Second,
157
+	}
158
+
159
+	addr := fmt.Sprintf("https://%s/headers", suite.ProxyAddress())
160
+
161
+	resp, err := client.Get(addr) // nolint: noctx
162
+	suite.NoError(err)
163
+
164
+	defer resp.Body.Close()
165
+
166
+	suite.Equal(http.StatusOK, resp.StatusCode)
167
+
168
+	data, err := io.ReadAll(resp.Body)
169
+	suite.NoError(err)
170
+
171
+	jsonStruct := struct {
172
+		Headers struct {
173
+			TraceID string `json:"X-Amzn-Trace-Id"`
174
+		} `json:"headers"`
175
+	}{}
176
+
177
+	suite.NoError(json.Unmarshal(data, &jsonStruct))
178
+	suite.NotEmpty(jsonStruct.Headers.TraceID)
179
+}
180
+
181
+func TestProxy(t *testing.T) {
182
+	t.Parallel()
183
+	suite.Run(t, &ProxyTestSuite{})
184
+}

Laddar…
Avbryt
Spara