Explorar el Código

Add base tests for proxy

tags/v2.0.0-rc1
9seconds hace 5 años
padre
commit
2cdff017e2
Se han modificado 2 ficheros con 212 adiciones y 12 borrados
  1. 28
    12
      mtglib/proxy.go
  2. 184
    0
      mtglib/proxy_test.go

+ 28
- 12
mtglib/proxy.go Ver fichero

42
 }
42
 }
43
 
43
 
44
 func (p *Proxy) ServeConn(conn net.Conn) {
44
 func (p *Proxy) ServeConn(conn net.Conn) {
45
+	p.streamWaitGroup.Add(1)
46
+	defer p.streamWaitGroup.Done()
47
+
45
 	ctx := newStreamContext(p.ctx, p.logger, conn)
48
 	ctx := newStreamContext(p.ctx, p.logger, conn)
46
 	defer ctx.Close()
49
 	defer ctx.Close()
47
 
50
 
91
 }
94
 }
92
 
95
 
93
 func (p *Proxy) Serve(listener net.Listener) error {
96
 func (p *Proxy) Serve(listener net.Listener) error {
97
+	p.streamWaitGroup.Add(1)
98
+	defer p.streamWaitGroup.Done()
99
+
94
 	for {
100
 	for {
95
 		conn, err := listener.Accept()
101
 		conn, err := listener.Accept()
96
 		if err != nil {
102
 		if err != nil {
97
 			return fmt.Errorf("cannot accept a new connection: %w", err)
103
 			return fmt.Errorf("cannot accept a new connection: %w", err)
98
 		}
104
 		}
99
 
105
 
100
-		if addr := conn.RemoteAddr().(*net.TCPAddr).IP; p.ipBlocklist.Contains(addr) {
106
+		ipAddr := conn.RemoteAddr().(*net.TCPAddr).IP
107
+		logger := p.logger.BindStr("ip", ipAddr.String())
108
+
109
+		if p.ipBlocklist.Contains(ipAddr) {
101
 			conn.Close()
110
 			conn.Close()
102
-			p.logger.
103
-				BindStr("ip", conn.RemoteAddr().(*net.TCPAddr).IP.String()).
104
-				Info("ip was blacklisted")
111
+			logger.Info("ip was blacklisted")
105
 			p.eventStream.Send(p.ctx, EventIPBlocklisted{
112
 			p.eventStream.Send(p.ctx, EventIPBlocklisted{
106
 				CreatedAt: time.Now(),
113
 				CreatedAt: time.Now(),
107
-				RemoteIP:  addr,
114
+				RemoteIP:  ipAddr,
108
 			})
115
 			})
109
 
116
 
110
 			continue
117
 			continue
117
 		case errors.Is(err, ants.ErrPoolClosed):
124
 		case errors.Is(err, ants.ErrPoolClosed):
118
 			return nil
125
 			return nil
119
 		case errors.Is(err, ants.ErrPoolOverload):
126
 		case errors.Is(err, ants.ErrPoolOverload):
120
-			p.logger.
121
-				BindStr("ip", conn.RemoteAddr().(*net.TCPAddr).IP.String()).
122
-				Info("connection was concurrency limited")
127
+			logger.Info("connection was concurrency limited")
123
 			p.eventStream.Send(p.ctx, EventConcurrencyLimited{
128
 			p.eventStream.Send(p.ctx, EventConcurrencyLimited{
124
 				CreatedAt: time.Now(),
129
 				CreatedAt: time.Now(),
125
 			})
130
 			})
126
 		}
131
 		}
132
+
133
+		select {
134
+		case <-p.ctx.Done():
135
+			return p.ctx.Err()
136
+		default:
137
+		}
127
 	}
138
 	}
128
 }
139
 }
129
 
140
 
292
 		return nil, ErrSecretInvalid
303
 		return nil, ErrSecretInvalid
293
 	}
304
 	}
294
 
305
 
295
-	tg, err := telegram.New(opts.Network, opts.PreferIP)
296
-	if err != nil {
297
-		return nil, fmt.Errorf("cannot build telegram dialer: %w", err)
306
+	preferIP := opts.PreferIP
307
+	if preferIP == "" {
308
+		preferIP = DefaultPreferIP
298
 	}
309
 	}
299
 
310
 
300
 	concurrency := opts.Concurrency
311
 	concurrency := opts.Concurrency
317
 		domainFrontingPort = DefaultDomainFrontingPort
328
 		domainFrontingPort = DefaultDomainFrontingPort
318
 	}
329
 	}
319
 
330
 
331
+	tg, err := telegram.New(opts.Network, preferIP)
332
+	if err != nil {
333
+		return nil, fmt.Errorf("cannot build telegram dialer: %w", err)
334
+	}
335
+
320
 	ctx, cancel := context.WithCancel(context.Background())
336
 	ctx, cancel := context.WithCancel(context.Background())
321
 	proxy := &Proxy{
337
 	proxy := &Proxy{
322
 		ctx:                ctx,
338
 		ctx:                ctx,
340
 		ants.WithLogger(opts.Logger.Named("ants")),
356
 		ants.WithLogger(opts.Logger.Named("ants")),
341
 		ants.WithNonblocking(true))
357
 		ants.WithNonblocking(true))
342
 	if err != nil {
358
 	if err != nil {
343
-		return nil, fmt.Errorf("cannot initialize a pool: %w", err)
359
+		panic(err)
344
 	}
360
 	}
345
 
361
 
346
 	proxy.workerPool = pool
362
 	proxy.workerPool = pool

+ 184
- 0
mtglib/proxy_test.go Ver fichero

1
+package mtglib_test
2
+
3
+import (
4
+	"crypto/tls"
5
+	"encoding/json"
6
+	"fmt"
7
+	"io"
8
+	"net"
9
+	"net/http"
10
+	"testing"
11
+	"time"
12
+
13
+	"github.com/9seconds/mtg/v2/antireplay"
14
+	"github.com/9seconds/mtg/v2/events"
15
+	"github.com/9seconds/mtg/v2/ipblocklist"
16
+	"github.com/9seconds/mtg/v2/logger"
17
+	"github.com/9seconds/mtg/v2/mtglib"
18
+	"github.com/9seconds/mtg/v2/network"
19
+	"github.com/9seconds/mtg/v2/timeattack"
20
+	"github.com/stretchr/testify/suite"
21
+)
22
+
23
+type ProxyTestSuite struct {
24
+	suite.Suite
25
+
26
+	opts     *mtglib.ProxyOpts
27
+	p        *mtglib.Proxy
28
+	listener net.Listener
29
+}
30
+
31
+func (suite *ProxyTestSuite) ProxyAddress() string {
32
+	_, port, _ := net.SplitHostPort(suite.listener.Addr().String())
33
+
34
+	return net.JoinHostPort("127.0.0.1", port)
35
+}
36
+
37
+func (suite *ProxyTestSuite) ProxySecret() string {
38
+	return suite.opts.Secret.Hex()
39
+}
40
+
41
+func (suite *ProxyTestSuite) SetupSuite() {
42
+	dialer, err := network.NewDefaultDialer(0, 0)
43
+	suite.NoError(err)
44
+
45
+	ntw, err := network.NewNetwork(dialer, "mtgtest", "1.1.1.1", 0)
46
+	suite.NoError(err)
47
+
48
+	suite.opts = &mtglib.ProxyOpts{
49
+		Secret:             mtglib.GenerateSecret("httpbin.org"),
50
+		Network:            ntw,
51
+		AntiReplayCache:    antireplay.NewNoop(),
52
+		TimeAttackDetector: timeattack.NewNoop(),
53
+		IPBlocklist:        ipblocklist.NewNoop(),
54
+		EventStream:        events.NewNoopStream(),
55
+		Logger:             logger.NewNoopLogger(),
56
+	}
57
+
58
+	proxy, err := mtglib.NewProxy(*suite.opts)
59
+	suite.NoError(err)
60
+
61
+	suite.p = proxy
62
+
63
+	listener, err := net.Listen("tcp", ":0")
64
+	suite.NoError(err)
65
+
66
+	suite.listener = listener
67
+
68
+	go suite.p.Serve(suite.listener) // nolint: errcheck
69
+}
70
+
71
+func (suite *ProxyTestSuite) TearDownSuite() {
72
+	if suite.listener != nil {
73
+		suite.listener.Close()
74
+	}
75
+
76
+	if suite.p != nil {
77
+		suite.p.Shutdown()
78
+	}
79
+}
80
+
81
+func (suite *ProxyTestSuite) TestCannotInitNoSecret() {
82
+	opts := *suite.opts
83
+	opts.Secret = mtglib.Secret{}
84
+
85
+	_, err := mtglib.NewProxy(opts)
86
+	suite.Error(err)
87
+}
88
+
89
+func (suite *ProxyTestSuite) TestCannotInitNoNetwork() {
90
+	opts := *suite.opts
91
+	opts.Network = nil
92
+
93
+	_, err := mtglib.NewProxy(opts)
94
+	suite.Error(err)
95
+}
96
+
97
+func (suite *ProxyTestSuite) TestCannotInitNoAntiReplayCache() {
98
+	opts := *suite.opts
99
+	opts.AntiReplayCache = nil
100
+
101
+	_, err := mtglib.NewProxy(opts)
102
+	suite.Error(err)
103
+}
104
+
105
+func (suite *ProxyTestSuite) TestCannotInitNoIPBlocklist() {
106
+	opts := *suite.opts
107
+	opts.IPBlocklist = nil
108
+
109
+	_, err := mtglib.NewProxy(opts)
110
+	suite.Error(err)
111
+}
112
+
113
+func (suite *ProxyTestSuite) TestCannotInitNoEventStream() {
114
+	opts := *suite.opts
115
+	opts.EventStream = nil
116
+
117
+	_, err := mtglib.NewProxy(opts)
118
+	suite.Error(err)
119
+}
120
+
121
+func (suite *ProxyTestSuite) TestCannotInitNoTimeAttackDetector() {
122
+	opts := *suite.opts
123
+	opts.TimeAttackDetector = nil
124
+
125
+	_, err := mtglib.NewProxy(opts)
126
+	suite.Error(err)
127
+}
128
+
129
+func (suite *ProxyTestSuite) TestCannotInitNoLogger() {
130
+	opts := *suite.opts
131
+	opts.Logger = nil
132
+
133
+	_, err := mtglib.NewProxy(opts)
134
+	suite.Error(err)
135
+}
136
+
137
+func (suite *ProxyTestSuite) TestCannotInitIncorrectPreferIP() {
138
+	opts := *suite.opts
139
+	opts.PreferIP = "xxx"
140
+
141
+	_, err := mtglib.NewProxy(opts)
142
+	suite.Error(err)
143
+}
144
+
145
+func (suite *ProxyTestSuite) TestDomainFrontingAddress() {
146
+	suite.Equal("httpbin.org:443", suite.p.DomainFrontingAddress())
147
+}
148
+
149
+func (suite *ProxyTestSuite) TestHTTPSRequest() {
150
+	client := &http.Client{
151
+		Transport: &http.Transport{
152
+			TLSClientConfig: &tls.Config{
153
+				InsecureSkipVerify: true,
154
+			},
155
+		},
156
+		Timeout: 5 * time.Second,
157
+	}
158
+
159
+	addr := fmt.Sprintf("https://%s/headers", suite.ProxyAddress())
160
+
161
+	resp, err := client.Get(addr) // nolint: noctx
162
+	suite.NoError(err)
163
+
164
+	defer resp.Body.Close()
165
+
166
+	suite.Equal(http.StatusOK, resp.StatusCode)
167
+
168
+	data, err := io.ReadAll(resp.Body)
169
+	suite.NoError(err)
170
+
171
+	jsonStruct := struct {
172
+		Headers struct {
173
+			TraceID string `json:"X-Amzn-Trace-Id"`
174
+		} `json:"headers"`
175
+	}{}
176
+
177
+	suite.NoError(json.Unmarshal(data, &jsonStruct))
178
+	suite.NotEmpty(jsonStruct.Headers.TraceID)
179
+}
180
+
181
+func TestProxy(t *testing.T) {
182
+	t.Parallel()
183
+	suite.Run(t, &ProxyTestSuite{})
184
+}

Loading…
Cancelar
Guardar