Просмотр исходного кода

Merge upstream/master: idiomatic Go, IP priority, flaky test fixes

Upstream changes:
- More idiomatic Golang (b6427ee)
- Change IP address set priority (1fcec38)
- Flaky CI test fixes (eedee63, 73c6a3a, e54d9d6)

Conflicts resolved:
- README.md: keep fork version
- default.pgo: removed (causes SEGV in fork builds)
- proxy.go: keep fork's counting conn, use upstream's tracker
- proxy_opts.go: keep fork's APIBindTo field
pull/434/head
Alexey Dolotov 1 месяц назад
Родитель
Сommit
9399bd5427

+ 1
- 1
antireplay/stable_bloom_filter_test.go Просмотреть файл

@@ -12,7 +12,7 @@ type StableBloomFilterTestSuite struct {
12 12
 }
13 13
 
14 14
 func (suite *StableBloomFilterTestSuite) TestOp() {
15
-	filter := antireplay.NewStableBloomFilter(500, 0.001)
15
+	filter := antireplay.NewStableBloomFilter(100000, 0.001)
16 16
 
17 17
 	suite.False(filter.SeenBefore([]byte{1, 2, 3}))
18 18
 	suite.False(filter.SeenBefore([]byte{4, 5, 6}))

+ 4
- 4
internal/config/config.go Просмотреть файл

@@ -53,10 +53,10 @@ type Config struct {
53 53
 		Blocklist    ListConfig `json:"blocklist"`
54 54
 		Allowlist    ListConfig `json:"allowlist"`
55 55
 		Doppelganger struct {
56
-			URLs            []TypeHttpsURL  `json:"urls"`
57
-			Repeats         TypeConcurrency `json:"repeats_per_raid"`
58
-			UpdateEach      TypeDuration    `json:"raid_each"`
59
-			DRS             TypeBool        `json:"drs"`
56
+			URLs       []TypeHttpsURL  `json:"urls"`
57
+			Repeats    TypeConcurrency `json:"repeats_per_raid"`
58
+			UpdateEach TypeDuration    `json:"raid_each"`
59
+			DRS        TypeBool        `json:"drs"`
60 60
 		} `json:"doppelganger"`
61 61
 	} `json:"defense"`
62 62
 	Network struct {

+ 4
- 4
internal/config/parse.go Просмотреть файл

@@ -48,10 +48,10 @@ type tomlConfig struct {
48 48
 			UpdateEach          string   `toml:"update-each" json:"updateEach,omitempty"`
49 49
 		} `toml:"allowlist" json:"allowlist,omitempty"`
50 50
 		Doppelganger struct {
51
-			URLs            []string `toml:"urls" json:"urls,omitempty"`
52
-			Repeats         uint     `toml:"repeats-per-raid" json:"repeats_per_raid,omitempty"`
53
-			UpdateEach      string   `toml:"raid-each" json:"raid_each,omitempty"`
54
-			DRS             bool     `toml:"drs" json:"drs,omitempty"`
51
+			URLs       []string `toml:"urls" json:"urls,omitempty"`
52
+			Repeats    uint     `toml:"repeats-per-raid" json:"repeats_per_raid,omitempty"`
53
+			UpdateEach string   `toml:"raid-each" json:"raid_each,omitempty"`
54
+			DRS        bool     `toml:"drs" json:"drs,omitempty"`
55 55
 		} `toml:"doppelganger" json:"doppelganger,omitempty"`
56 56
 	} `toml:"defense" json:"defense,omitempty"`
57 57
 	Network struct {

+ 50
- 5
mtglib/conns.go Просмотреть файл

@@ -3,9 +3,11 @@ package mtglib
3 3
 import (
4 4
 	"bytes"
5 5
 	"context"
6
+	"errors"
6 7
 	"fmt"
7 8
 	"io"
8 9
 	"net"
10
+	"sync/atomic"
9 11
 	"time"
10 12
 
11 13
 	"github.com/9seconds/mtg/v2/essentials"
@@ -97,20 +99,63 @@ func newConnProxyProtocol(source, target essentials.Conn) *connProxyProtocol {
97 99
 	}
98 100
 }
99 101
 
102
+// idleTracker is a shared idle tracker for a pair of relay connections.
103
+// Both directions update the same timestamp so that activity in one direction
104
+// prevents the other (idle) direction from timing out.
105
+type idleTracker struct {
106
+	lastActive atomic.Pointer[time.Time]
107
+	timeout    time.Duration
108
+}
109
+
110
+func newIdleTracker(timeout time.Duration) *idleTracker {
111
+	t := &idleTracker{timeout: timeout}
112
+	t.touch()
113
+
114
+	return t
115
+}
116
+
117
+func (t *idleTracker) touch() {
118
+	stamp := time.Now()
119
+	t.lastActive.Store(&stamp)
120
+}
121
+
122
+func (t *idleTracker) isIdle() bool {
123
+	return time.Since(*t.lastActive.Load()) >= t.timeout
124
+}
125
+
100 126
 type connIdleTimeout struct {
101 127
 	essentials.Conn
102 128
 
103
-	timeout time.Duration
129
+	tracker *idleTracker
104 130
 }
105 131
 
106 132
 func (c connIdleTimeout) Read(b []byte) (int, error) {
107
-	c.SetReadDeadline(time.Now().Add(c.timeout)) //nolint: errcheck
133
+	var netErr net.Error
134
+
135
+	for {
136
+		c.SetReadDeadline(time.Now().Add(c.tracker.timeout)) //nolint: errcheck
108 137
 
109
-	return c.Conn.Read(b) //nolint: wrapcheck
138
+		n, err := c.Conn.Read(b)
139
+
140
+		switch {
141
+		case err == nil:
142
+			c.tracker.touch()
143
+			return n, nil
144
+		case errors.As(err, &netErr) && netErr.Timeout() && !c.tracker.isIdle():
145
+			continue
146
+		}
147
+
148
+		return n, err
149
+	}
110 150
 }
111 151
 
112 152
 func (c connIdleTimeout) Write(b []byte) (int, error) {
113
-	c.SetWriteDeadline(time.Now().Add(c.timeout)) //nolint: errcheck
153
+	c.SetWriteDeadline(time.Now().Add(c.tracker.timeout)) //nolint: errcheck
114 154
 
115
-	return c.Conn.Write(b) //nolint: wrapcheck
155
+	n, err := c.Conn.Write(b)
156
+	if n > 0 {
157
+		c.tracker.touch()
158
+	}
159
+
160
+	return n, err //nolint: wrapcheck
116 161
 }

+ 151
- 0
mtglib/conns_internal_test.go Просмотреть файл

@@ -16,6 +16,12 @@ import (
16 16
 	"github.com/stretchr/testify/suite"
17 17
 )
18 18
 
19
+type netTimeoutError struct{}
20
+
21
+func (e netTimeoutError) Error() string   { return "i/o timeout" }
22
+func (e netTimeoutError) Timeout() bool   { return true }
23
+func (e netTimeoutError) Temporary() bool { return true }
24
+
19 25
 type ConnRewindBaseConn struct {
20 26
 	testlib.EssentialsConnMock
21 27
 
@@ -291,6 +297,141 @@ func (suite *ConnProxyProtocolTestSuite) TearDownTest() {
291 297
 	suite.targetConnMock.AssertExpectations(suite.T())
292 298
 }
293 299
 
300
+type IdleTrackerTestSuite struct {
301
+	suite.Suite
302
+}
303
+
304
+func (suite *IdleTrackerTestSuite) TestNewNotIdle() {
305
+	tracker := newIdleTracker(time.Second)
306
+	suite.False(tracker.isIdle())
307
+}
308
+
309
+func (suite *IdleTrackerTestSuite) TestIdleAfterTimeout() {
310
+	tracker := newIdleTracker(10 * time.Millisecond)
311
+	time.Sleep(20 * time.Millisecond)
312
+
313
+	suite.True(tracker.isIdle())
314
+}
315
+
316
+func (suite *IdleTrackerTestSuite) TestTouchResetsIdle() {
317
+	tracker := newIdleTracker(50 * time.Millisecond)
318
+	time.Sleep(30 * time.Millisecond)
319
+
320
+	tracker.touch()
321
+
322
+	suite.False(tracker.isIdle())
323
+}
324
+
325
+type ConnIdleTimeoutTestSuite struct {
326
+	suite.Suite
327
+
328
+	connMock *testlib.EssentialsConnMock
329
+	tracker  *idleTracker
330
+	conn     connIdleTimeout
331
+}
332
+
333
+func (suite *ConnIdleTimeoutTestSuite) SetupTest() {
334
+	suite.connMock = &testlib.EssentialsConnMock{}
335
+	suite.tracker = newIdleTracker(time.Second)
336
+	suite.conn = connIdleTimeout{
337
+		Conn:    suite.connMock,
338
+		tracker: suite.tracker,
339
+	}
340
+}
341
+
342
+func (suite *ConnIdleTimeoutTestSuite) TearDownTest() {
343
+	suite.connMock.AssertExpectations(suite.T())
344
+}
345
+
346
+func (suite *ConnIdleTimeoutTestSuite) TestReadOk() {
347
+	suite.connMock.On("SetReadDeadline", mock.Anything).Return(nil)
348
+	suite.connMock.On("Read", mock.Anything).Once().Return(5, nil)
349
+
350
+	n, err := suite.conn.Read(make([]byte, 10))
351
+	suite.NoError(err)
352
+	suite.Equal(5, n)
353
+}
354
+
355
+func (suite *ConnIdleTimeoutTestSuite) TestReadNonTimeoutErr() {
356
+	suite.connMock.On("SetReadDeadline", mock.Anything).Return(nil)
357
+	suite.connMock.On("Read", mock.Anything).Once().Return(0, io.EOF)
358
+
359
+	n, err := suite.conn.Read(make([]byte, 10))
360
+	suite.True(errors.Is(err, io.EOF))
361
+	suite.Equal(0, n)
362
+}
363
+
364
+func (suite *ConnIdleTimeoutTestSuite) TestReadTimeoutRetriesWhenNotIdle() {
365
+	suite.connMock.On("SetReadDeadline", mock.Anything).Return(nil)
366
+	suite.connMock.On("Read", mock.Anything).Once().Return(0, netTimeoutError{})
367
+	suite.connMock.On("Read", mock.Anything).Once().Return(5, nil)
368
+
369
+	n, err := suite.conn.Read(make([]byte, 10))
370
+	suite.NoError(err)
371
+	suite.Equal(5, n)
372
+}
373
+
374
+func (suite *ConnIdleTimeoutTestSuite) TestReadTimeoutClosesWhenIdle() {
375
+	suite.tracker = newIdleTracker(time.Millisecond)
376
+	suite.conn = connIdleTimeout{
377
+		Conn:    suite.connMock,
378
+		tracker: suite.tracker,
379
+	}
380
+
381
+	time.Sleep(5 * time.Millisecond)
382
+
383
+	suite.connMock.On("SetReadDeadline", mock.Anything).Return(nil)
384
+	suite.connMock.On("Read", mock.Anything).Once().Return(0, netTimeoutError{})
385
+
386
+	n, err := suite.conn.Read(make([]byte, 10))
387
+	suite.Equal(0, n)
388
+
389
+	netErr, ok := err.(net.Error) //nolint: errorlint
390
+	suite.True(ok)
391
+	suite.True(netErr.Timeout())
392
+}
393
+
394
+func (suite *ConnIdleTimeoutTestSuite) TestSharedTrackerPreventsFalseTimeout() {
395
+	connMock2 := &testlib.EssentialsConnMock{}
396
+	conn2 := connIdleTimeout{
397
+		Conn:    connMock2,
398
+		tracker: suite.tracker,
399
+	}
400
+
401
+	connMock2.On("SetWriteDeadline", mock.Anything).Return(nil)
402
+	connMock2.On("Write", mock.Anything).Once().Return(5, nil)
403
+
404
+	_, _ = conn2.Write(make([]byte, 5))
405
+
406
+	suite.connMock.On("SetReadDeadline", mock.Anything).Return(nil)
407
+	suite.connMock.On("Read", mock.Anything).Once().Return(0, netTimeoutError{})
408
+	suite.connMock.On("Read", mock.Anything).Once().Return(3, nil)
409
+
410
+	n, err := suite.conn.Read(make([]byte, 10))
411
+	suite.NoError(err)
412
+	suite.Equal(3, n)
413
+
414
+	connMock2.AssertExpectations(suite.T())
415
+}
416
+
417
+func (suite *ConnIdleTimeoutTestSuite) TestWriteOk() {
418
+	suite.connMock.On("SetWriteDeadline", mock.Anything).Return(nil)
419
+	suite.connMock.On("Write", mock.Anything).Once().Return(5, nil)
420
+
421
+	n, err := suite.conn.Write(make([]byte, 5))
422
+	suite.NoError(err)
423
+	suite.Equal(5, n)
424
+}
425
+
426
+func (suite *ConnIdleTimeoutTestSuite) TestWriteErr() {
427
+	suite.connMock.On("SetWriteDeadline", mock.Anything).Return(nil)
428
+	suite.connMock.On("Write", mock.Anything).Once().Return(0, io.EOF)
429
+
430
+	n, err := suite.conn.Write(make([]byte, 5))
431
+	suite.True(errors.Is(err, io.EOF))
432
+	suite.Equal(0, n)
433
+}
434
+
294 435
 func TestConnTraffic(t *testing.T) {
295 436
 	t.Parallel()
296 437
 	suite.Run(t, &ConnTrafficTestSuite{})
@@ -305,3 +446,13 @@ func TestConnProxyProtocol(t *testing.T) {
305 446
 	t.Parallel()
306 447
 	suite.Run(t, &ConnProxyProtocolTestSuite{})
307 448
 }
449
+
450
+func TestIdleTracker(t *testing.T) {
451
+	t.Parallel()
452
+	suite.Run(t, &IdleTrackerTestSuite{})
453
+}
454
+
455
+func TestConnIdleTimeout(t *testing.T) {
456
+	t.Parallel()
457
+	suite.Run(t, &ConnIdleTimeoutTestSuite{})
458
+}

+ 6
- 2
mtglib/internal/dc/view.go Просмотреть файл

@@ -5,15 +5,19 @@ type dcView struct {
5 5
 }
6 6
 
7 7
 func (d dcView) getV4(dc int) []Addr {
8
-	addrs := d.publicConfigs.getV4(dc)
8
+	var addrs []Addr
9
+
9 10
 	addrs = append(addrs, defaultDCAddrSet.getV4(dc)...)
11
+	addrs = append(addrs, d.publicConfigs.getV4(dc)...)
10 12
 
11 13
 	return addrs
12 14
 }
13 15
 
14 16
 func (d dcView) getV6(dc int) []Addr {
15
-	addrs := d.publicConfigs.getV6(dc)
17
+	var addrs []Addr
18
+
16 19
 	addrs = append(addrs, defaultDCAddrSet.getV6(dc)...)
20
+	addrs = append(addrs, d.publicConfigs.getV6(dc)...)
17 21
 
18 22
 	return addrs
19 23
 }

+ 12
- 6
mtglib/internal/doppel/scout.go Просмотреть файл

@@ -61,23 +61,29 @@ func (s Scout) learn(ctx context.Context, url string) (ScoutResult, error) {
61 61
 		client.CloseIdleConnections()
62 62
 	}
63 63
 
64
-	if err != nil || len(results.data) == 0 {
64
+	if err != nil {
65 65
 		return ScoutResult{}, err
66 66
 	}
67 67
 
68
+	data, writeIndex := results.Snapshot()
69
+
70
+	if len(data) == 0 {
71
+		return ScoutResult{}, nil
72
+	}
73
+
68 74
 	var result ScoutResult
69 75
 
70 76
 	// Compute inter-record durations (existing logic).
71 77
 	lastTimestamp := time.Time{}
72 78
 
73
-	for i, v := range results.data {
79
+	for i, v := range data {
74 80
 		if v.recordType != tls.TypeApplicationData {
75 81
 			continue
76 82
 		}
77 83
 
78 84
 		if lastTimestamp.IsZero() {
79 85
 			if i > 0 {
80
-				lastTimestamp = results.data[i-1].timestamp
86
+				lastTimestamp = data[i-1].timestamp
81 87
 			} else {
82 88
 				lastTimestamp = v.timestamp
83 89
 			}
@@ -90,12 +96,12 @@ func (s Scout) learn(ctx context.Context, url string) (ScoutResult, error) {
90 96
 	// Compute cert size: sum of ApplicationData payload between CCS and
91 97
 	// the first client Write (which marks the end of server handshake).
92 98
 	seenCCS := false
93
-	boundary := results.writeIndex
99
+	boundary := writeIndex
94 100
 	if boundary < 0 {
95
-		boundary = len(results.data)
101
+		boundary = len(data)
96 102
 	}
97 103
 
98
-	for i, v := range results.data {
104
+	for i, v := range data {
99 105
 		if i >= boundary {
100 106
 			break
101 107
 		}

+ 20
- 1
mtglib/internal/doppel/scout_conn_collected.go Просмотреть файл

@@ -1,6 +1,10 @@
1 1
 package doppel
2 2
 
3
-import "time"
3
+import (
4
+	"slices"
5
+	"sync"
6
+	"time"
7
+)
4 8
 
5 9
 const (
6 10
 	ScoutConnCollectedPreallocSize = 100
@@ -13,23 +17,38 @@ type ScoutConnResult struct {
13 17
 }
14 18
 
15 19
 type ScoutConnCollected struct {
20
+	mu         sync.Mutex
16 21
 	data       []ScoutConnResult
17 22
 	writeIndex int // index at which client first wrote post-handshake data; -1 if not set
18 23
 }
19 24
 
20 25
 func (s *ScoutConnCollected) Add(record byte, payloadLen int) {
26
+	s.mu.Lock()
21 27
 	s.data = append(s.data, ScoutConnResult{
22 28
 		timestamp:  time.Now(),
23 29
 		recordType: record,
24 30
 		payloadLen: payloadLen,
25 31
 	})
32
+	s.mu.Unlock()
26 33
 }
27 34
 
28 35
 // MarkWrite records the current data length as the handshake boundary.
29 36
 func (s *ScoutConnCollected) MarkWrite() {
37
+	s.mu.Lock()
30 38
 	if s.writeIndex < 0 {
31 39
 		s.writeIndex = len(s.data)
32 40
 	}
41
+	s.mu.Unlock()
42
+}
43
+
44
+// Snapshot returns a copy of the collected data and the write index.
45
+func (s *ScoutConnCollected) Snapshot() ([]ScoutConnResult, int) {
46
+	s.mu.Lock()
47
+	snapshot := slices.Clone(s.data)
48
+	writeIndex := s.writeIndex
49
+	s.mu.Unlock()
50
+
51
+	return snapshot, writeIndex
33 52
 }
34 53
 
35 54
 func NewScoutConnCollected() *ScoutConnCollected {

+ 48
- 4
mtglib/internal/doppel/scout_conn_collected_test.go Просмотреть файл

@@ -1,6 +1,7 @@
1 1
 package doppel
2 2
 
3 3
 import (
4
+	"sync"
4 5
 	"testing"
5 6
 	"time"
6 7
 
@@ -16,8 +17,10 @@ func (suite *ScoutConnCollectedTestSuite) TestAddSingle() {
16 17
 	collected := NewScoutConnCollected()
17 18
 	collected.Add(tls.TypeApplicationData, 100)
18 19
 
19
-	suite.Len(collected.data, 1)
20
-	suite.Equal(byte(tls.TypeApplicationData), collected.data[0].recordType)
20
+	data, _ := collected.Snapshot()
21
+
22
+	suite.Len(data, 1)
23
+	suite.Equal(byte(tls.TypeApplicationData), data[0].recordType)
21 24
 }
22 25
 
23 26
 func (suite *ScoutConnCollectedTestSuite) TestAddTimestampsAreMonotonic() {
@@ -31,11 +34,52 @@ func (suite *ScoutConnCollectedTestSuite) TestAddTimestampsAreMonotonic() {
31 34
 	time.Sleep(time.Microsecond)
32 35
 	collected.Add(tls.TypeApplicationData, 100)
33 36
 
34
-	for i := 1; i < len(collected.data); i++ {
35
-		suite.True(collected.data[i].timestamp.After(collected.data[i-1].timestamp))
37
+	data, _ := collected.Snapshot()
38
+
39
+	for i := 1; i < len(data); i++ {
40
+		suite.True(data[i].timestamp.After(data[i-1].timestamp))
36 41
 	}
37 42
 }
38 43
 
44
+func (suite *ScoutConnCollectedTestSuite) TestConcurrentAddSnapshot() {
45
+	collected := NewScoutConnCollected()
46
+
47
+	var wg sync.WaitGroup
48
+
49
+	wg.Add(3)
50
+
51
+	go func() {
52
+		defer wg.Done()
53
+
54
+		for i := 0; i < 1000; i++ {
55
+			collected.Add(tls.TypeApplicationData, i)
56
+		}
57
+	}()
58
+
59
+	go func() {
60
+		defer wg.Done()
61
+
62
+		for i := 0; i < 100; i++ {
63
+			collected.MarkWrite()
64
+		}
65
+	}()
66
+
67
+	go func() {
68
+		defer wg.Done()
69
+
70
+		for i := 0; i < 1000; i++ {
71
+			// call Snapshot concurrently to exercise the lock under -race
72
+			collected.Snapshot() //nolint:errcheck
73
+		}
74
+	}()
75
+
76
+	wg.Wait()
77
+
78
+	data, writeIndex := collected.Snapshot()
79
+	suite.Len(data, 1000)
80
+	suite.GreaterOrEqual(writeIndex, 0)
81
+}
82
+
39 83
 func TestScoutConnCollected(t *testing.T) {
40 84
 	t.Parallel()
41 85
 	suite.Run(t, &ScoutConnCollectedTestSuite{})

+ 8
- 4
mtglib/proxy.go Просмотреть файл

@@ -111,11 +111,13 @@ func (p *Proxy) ServeConn(conn essentials.Conn) {
111 111
 		return
112 112
 	}
113 113
 
114
+	tracker := newIdleTracker(p.idleTimeout)
115
+
114 116
 	relay.Relay(
115 117
 		ctx,
116 118
 		ctx.logger.Named("relay"),
117
-		connIdleTimeout{Conn: ctx.telegramConn, timeout: p.idleTimeout},
118
-		newCountingConn(connIdleTimeout{Conn: ctx.clientConn, timeout: p.idleTimeout}, p.stats, ctx.secretName),
119
+		connIdleTimeout{Conn: ctx.telegramConn, tracker: tracker},
120
+		newCountingConn(connIdleTimeout{Conn: ctx.clientConn, tracker: tracker}, p.stats, ctx.secretName),
119 121
 	)
120 122
 }
121 123
 
@@ -330,11 +332,13 @@ func (p *Proxy) doDomainFronting(ctx *streamContext, conn *connRewind) {
330 332
 		stream:   p.eventStream,
331 333
 	}
332 334
 
335
+	tracker := newIdleTracker(p.idleTimeout)
336
+
333 337
 	relay.Relay(
334 338
 		ctx,
335 339
 		ctx.logger.Named("domain-fronting"),
336
-		connIdleTimeout{Conn: frontConn, timeout: p.idleTimeout},
337
-		connIdleTimeout{Conn: conn, timeout: p.idleTimeout},
340
+		connIdleTimeout{Conn: frontConn, tracker: tracker},
341
+		connIdleTimeout{Conn: conn, tracker: tracker},
338 342
 	)
339 343
 }
340 344
 

+ 1
- 1
mtglib/proxy_test.go Просмотреть файл

@@ -175,7 +175,7 @@ func (suite *ProxyTestSuite) TestHTTPSRequest() {
175 175
 	addr := fmt.Sprintf("https://%s/headers", suite.ProxyAddress())
176 176
 
177 177
 	resp, err := client.Get(addr) //nolint: noctx
178
-	suite.NoError(err)
178
+	suite.Require().NoError(err)
179 179
 
180 180
 	defer resp.Body.Close() //nolint: errcheck
181 181
 

Загрузка…
Отмена
Сохранить