Просмотр исходного кода

Merge remote-tracking branch 'upstream/master' into fix/toml-section-order

# Conflicts:
#	.github/workflows/ci.yaml
#	README.md
#	default.pgo
#	mtglib/internal/tls/fake/client_side.go
#	mtglib/internal/tls/fake/client_side_test.go
#	mtglib/proxy.go
#	mtglib/proxy_opts.go
pull/434/head
Alexey Dolotov 1 месяц назад
Родитель
Сommit
aec151d158

+ 1
- 1
Dockerfile Просмотреть файл

33
 COPY . /app
33
 COPY . /app
34
 
34
 
35
 RUN set -x \
35
 RUN set -x \
36
-  && version="$(git describe --exact-match HEAD 2>/dev/null || git describe --tags --always)" \
36
+  && version="$(git describe --exact-match HEAD 2>/dev/null || git describe --tags --always 2>/dev/null || echo dev)" \
37
   && go build \
37
   && go build \
38
       -trimpath \
38
       -trimpath \
39
       -mod=readonly \
39
       -mod=readonly \

+ 1
- 1
antireplay/stable_bloom_filter_test.go Просмотреть файл

12
 }
12
 }
13
 
13
 
14
 func (suite *StableBloomFilterTestSuite) TestOp() {
14
 func (suite *StableBloomFilterTestSuite) TestOp() {
15
-	filter := antireplay.NewStableBloomFilter(500, 0.001)
15
+	filter := antireplay.NewStableBloomFilter(100000, 0.001)
16
 
16
 
17
 	suite.False(filter.SeenBefore([]byte{1, 2, 3}))
17
 	suite.False(filter.SeenBefore([]byte{1, 2, 3}))
18
 	suite.False(filter.SeenBefore([]byte{4, 5, 6}))
18
 	suite.False(filter.SeenBefore([]byte{4, 5, 6}))

+ 4
- 4
internal/config/config.go Просмотреть файл

53
 		Blocklist    ListConfig `json:"blocklist"`
53
 		Blocklist    ListConfig `json:"blocklist"`
54
 		Allowlist    ListConfig `json:"allowlist"`
54
 		Allowlist    ListConfig `json:"allowlist"`
55
 		Doppelganger struct {
55
 		Doppelganger struct {
56
-			URLs            []TypeHttpsURL  `json:"urls"`
57
-			Repeats         TypeConcurrency `json:"repeats_per_raid"`
58
-			UpdateEach      TypeDuration    `json:"raid_each"`
59
-			DRS             TypeBool        `json:"drs"`
56
+			URLs       []TypeHttpsURL  `json:"urls"`
57
+			Repeats    TypeConcurrency `json:"repeats_per_raid"`
58
+			UpdateEach TypeDuration    `json:"raid_each"`
59
+			DRS        TypeBool        `json:"drs"`
60
 		} `json:"doppelganger"`
60
 		} `json:"doppelganger"`
61
 	} `json:"defense"`
61
 	} `json:"defense"`
62
 	Network struct {
62
 	Network struct {

+ 4
- 4
internal/config/parse.go Просмотреть файл

48
 			UpdateEach          string   `toml:"update-each" json:"updateEach,omitempty"`
48
 			UpdateEach          string   `toml:"update-each" json:"updateEach,omitempty"`
49
 		} `toml:"allowlist" json:"allowlist,omitempty"`
49
 		} `toml:"allowlist" json:"allowlist,omitempty"`
50
 		Doppelganger struct {
50
 		Doppelganger struct {
51
-			URLs            []string `toml:"urls" json:"urls,omitempty"`
52
-			Repeats         uint     `toml:"repeats-per-raid" json:"repeats_per_raid,omitempty"`
53
-			UpdateEach      string   `toml:"raid-each" json:"raid_each,omitempty"`
54
-			DRS             bool     `toml:"drs" json:"drs,omitempty"`
51
+			URLs       []string `toml:"urls" json:"urls,omitempty"`
52
+			Repeats    uint     `toml:"repeats-per-raid" json:"repeats_per_raid,omitempty"`
53
+			UpdateEach string   `toml:"raid-each" json:"raid_each,omitempty"`
54
+			DRS        bool     `toml:"drs" json:"drs,omitempty"`
55
 		} `toml:"doppelganger" json:"doppelganger,omitempty"`
55
 		} `toml:"doppelganger" json:"doppelganger,omitempty"`
56
 	} `toml:"defense" json:"defense,omitempty"`
56
 	} `toml:"defense" json:"defense,omitempty"`
57
 	Network struct {
57
 	Network struct {

+ 50
- 5
mtglib/conns.go Просмотреть файл

3
 import (
3
 import (
4
 	"bytes"
4
 	"bytes"
5
 	"context"
5
 	"context"
6
+	"errors"
6
 	"fmt"
7
 	"fmt"
7
 	"io"
8
 	"io"
8
 	"net"
9
 	"net"
10
+	"sync/atomic"
9
 	"time"
11
 	"time"
10
 
12
 
11
 	"github.com/9seconds/mtg/v2/essentials"
13
 	"github.com/9seconds/mtg/v2/essentials"
97
 	}
99
 	}
98
 }
100
 }
99
 
101
 
102
+// idleTracker is a shared idle tracker for a pair of relay connections.
103
+// Both directions update the same timestamp so that activity in one direction
104
+// prevents the other (idle) direction from timing out.
105
+type idleTracker struct {
106
+	lastActive atomic.Pointer[time.Time]
107
+	timeout    time.Duration
108
+}
109
+
110
+func newIdleTracker(timeout time.Duration) *idleTracker {
111
+	t := &idleTracker{timeout: timeout}
112
+	t.touch()
113
+
114
+	return t
115
+}
116
+
117
+func (t *idleTracker) touch() {
118
+	stamp := time.Now()
119
+	t.lastActive.Store(&stamp)
120
+}
121
+
122
+func (t *idleTracker) isIdle() bool {
123
+	return time.Since(*t.lastActive.Load()) >= t.timeout
124
+}
125
+
100
 type connIdleTimeout struct {
126
 type connIdleTimeout struct {
101
 	essentials.Conn
127
 	essentials.Conn
102
 
128
 
103
-	timeout time.Duration
129
+	tracker *idleTracker
104
 }
130
 }
105
 
131
 
106
 func (c connIdleTimeout) Read(b []byte) (int, error) {
132
 func (c connIdleTimeout) Read(b []byte) (int, error) {
107
-	c.SetReadDeadline(time.Now().Add(c.timeout)) //nolint: errcheck
133
+	var netErr net.Error
134
+
135
+	for {
136
+		c.SetReadDeadline(time.Now().Add(c.tracker.timeout)) //nolint: errcheck
108
 
137
 
109
-	return c.Conn.Read(b) //nolint: wrapcheck
138
+		n, err := c.Conn.Read(b)
139
+
140
+		switch {
141
+		case err == nil:
142
+			c.tracker.touch()
143
+			return n, nil
144
+		case errors.As(err, &netErr) && netErr.Timeout() && !c.tracker.isIdle():
145
+			continue
146
+		}
147
+
148
+		return n, err
149
+	}
110
 }
150
 }
111
 
151
 
112
 func (c connIdleTimeout) Write(b []byte) (int, error) {
152
 func (c connIdleTimeout) Write(b []byte) (int, error) {
113
-	c.SetWriteDeadline(time.Now().Add(c.timeout)) //nolint: errcheck
153
+	c.SetWriteDeadline(time.Now().Add(c.tracker.timeout)) //nolint: errcheck
114
 
154
 
115
-	return c.Conn.Write(b) //nolint: wrapcheck
155
+	n, err := c.Conn.Write(b)
156
+	if n > 0 {
157
+		c.tracker.touch()
158
+	}
159
+
160
+	return n, err //nolint: wrapcheck
116
 }
161
 }

+ 151
- 0
mtglib/conns_internal_test.go Просмотреть файл

16
 	"github.com/stretchr/testify/suite"
16
 	"github.com/stretchr/testify/suite"
17
 )
17
 )
18
 
18
 
19
+type netTimeoutError struct{}
20
+
21
+func (e netTimeoutError) Error() string   { return "i/o timeout" }
22
+func (e netTimeoutError) Timeout() bool   { return true }
23
+func (e netTimeoutError) Temporary() bool { return true }
24
+
19
 type ConnRewindBaseConn struct {
25
 type ConnRewindBaseConn struct {
20
 	testlib.EssentialsConnMock
26
 	testlib.EssentialsConnMock
21
 
27
 
291
 	suite.targetConnMock.AssertExpectations(suite.T())
297
 	suite.targetConnMock.AssertExpectations(suite.T())
292
 }
298
 }
293
 
299
 
300
+type IdleTrackerTestSuite struct {
301
+	suite.Suite
302
+}
303
+
304
+func (suite *IdleTrackerTestSuite) TestNewNotIdle() {
305
+	tracker := newIdleTracker(time.Second)
306
+	suite.False(tracker.isIdle())
307
+}
308
+
309
+func (suite *IdleTrackerTestSuite) TestIdleAfterTimeout() {
310
+	tracker := newIdleTracker(10 * time.Millisecond)
311
+	time.Sleep(20 * time.Millisecond)
312
+
313
+	suite.True(tracker.isIdle())
314
+}
315
+
316
+func (suite *IdleTrackerTestSuite) TestTouchResetsIdle() {
317
+	tracker := newIdleTracker(50 * time.Millisecond)
318
+	time.Sleep(30 * time.Millisecond)
319
+
320
+	tracker.touch()
321
+
322
+	suite.False(tracker.isIdle())
323
+}
324
+
325
+type ConnIdleTimeoutTestSuite struct {
326
+	suite.Suite
327
+
328
+	connMock *testlib.EssentialsConnMock
329
+	tracker  *idleTracker
330
+	conn     connIdleTimeout
331
+}
332
+
333
+func (suite *ConnIdleTimeoutTestSuite) SetupTest() {
334
+	suite.connMock = &testlib.EssentialsConnMock{}
335
+	suite.tracker = newIdleTracker(time.Second)
336
+	suite.conn = connIdleTimeout{
337
+		Conn:    suite.connMock,
338
+		tracker: suite.tracker,
339
+	}
340
+}
341
+
342
+func (suite *ConnIdleTimeoutTestSuite) TearDownTest() {
343
+	suite.connMock.AssertExpectations(suite.T())
344
+}
345
+
346
+func (suite *ConnIdleTimeoutTestSuite) TestReadOk() {
347
+	suite.connMock.On("SetReadDeadline", mock.Anything).Return(nil)
348
+	suite.connMock.On("Read", mock.Anything).Once().Return(5, nil)
349
+
350
+	n, err := suite.conn.Read(make([]byte, 10))
351
+	suite.NoError(err)
352
+	suite.Equal(5, n)
353
+}
354
+
355
+func (suite *ConnIdleTimeoutTestSuite) TestReadNonTimeoutErr() {
356
+	suite.connMock.On("SetReadDeadline", mock.Anything).Return(nil)
357
+	suite.connMock.On("Read", mock.Anything).Once().Return(0, io.EOF)
358
+
359
+	n, err := suite.conn.Read(make([]byte, 10))
360
+	suite.True(errors.Is(err, io.EOF))
361
+	suite.Equal(0, n)
362
+}
363
+
364
+func (suite *ConnIdleTimeoutTestSuite) TestReadTimeoutRetriesWhenNotIdle() {
365
+	suite.connMock.On("SetReadDeadline", mock.Anything).Return(nil)
366
+	suite.connMock.On("Read", mock.Anything).Once().Return(0, netTimeoutError{})
367
+	suite.connMock.On("Read", mock.Anything).Once().Return(5, nil)
368
+
369
+	n, err := suite.conn.Read(make([]byte, 10))
370
+	suite.NoError(err)
371
+	suite.Equal(5, n)
372
+}
373
+
374
+func (suite *ConnIdleTimeoutTestSuite) TestReadTimeoutClosesWhenIdle() {
375
+	suite.tracker = newIdleTracker(time.Millisecond)
376
+	suite.conn = connIdleTimeout{
377
+		Conn:    suite.connMock,
378
+		tracker: suite.tracker,
379
+	}
380
+
381
+	time.Sleep(5 * time.Millisecond)
382
+
383
+	suite.connMock.On("SetReadDeadline", mock.Anything).Return(nil)
384
+	suite.connMock.On("Read", mock.Anything).Once().Return(0, netTimeoutError{})
385
+
386
+	n, err := suite.conn.Read(make([]byte, 10))
387
+	suite.Equal(0, n)
388
+
389
+	netErr, ok := err.(net.Error) //nolint: errorlint
390
+	suite.True(ok)
391
+	suite.True(netErr.Timeout())
392
+}
393
+
394
+func (suite *ConnIdleTimeoutTestSuite) TestSharedTrackerPreventsFalseTimeout() {
395
+	connMock2 := &testlib.EssentialsConnMock{}
396
+	conn2 := connIdleTimeout{
397
+		Conn:    connMock2,
398
+		tracker: suite.tracker,
399
+	}
400
+
401
+	connMock2.On("SetWriteDeadline", mock.Anything).Return(nil)
402
+	connMock2.On("Write", mock.Anything).Once().Return(5, nil)
403
+
404
+	_, _ = conn2.Write(make([]byte, 5))
405
+
406
+	suite.connMock.On("SetReadDeadline", mock.Anything).Return(nil)
407
+	suite.connMock.On("Read", mock.Anything).Once().Return(0, netTimeoutError{})
408
+	suite.connMock.On("Read", mock.Anything).Once().Return(3, nil)
409
+
410
+	n, err := suite.conn.Read(make([]byte, 10))
411
+	suite.NoError(err)
412
+	suite.Equal(3, n)
413
+
414
+	connMock2.AssertExpectations(suite.T())
415
+}
416
+
417
+func (suite *ConnIdleTimeoutTestSuite) TestWriteOk() {
418
+	suite.connMock.On("SetWriteDeadline", mock.Anything).Return(nil)
419
+	suite.connMock.On("Write", mock.Anything).Once().Return(5, nil)
420
+
421
+	n, err := suite.conn.Write(make([]byte, 5))
422
+	suite.NoError(err)
423
+	suite.Equal(5, n)
424
+}
425
+
426
+func (suite *ConnIdleTimeoutTestSuite) TestWriteErr() {
427
+	suite.connMock.On("SetWriteDeadline", mock.Anything).Return(nil)
428
+	suite.connMock.On("Write", mock.Anything).Once().Return(0, io.EOF)
429
+
430
+	n, err := suite.conn.Write(make([]byte, 5))
431
+	suite.True(errors.Is(err, io.EOF))
432
+	suite.Equal(0, n)
433
+}
434
+
294
 func TestConnTraffic(t *testing.T) {
435
 func TestConnTraffic(t *testing.T) {
295
 	t.Parallel()
436
 	t.Parallel()
296
 	suite.Run(t, &ConnTrafficTestSuite{})
437
 	suite.Run(t, &ConnTrafficTestSuite{})
305
 	t.Parallel()
446
 	t.Parallel()
306
 	suite.Run(t, &ConnProxyProtocolTestSuite{})
447
 	suite.Run(t, &ConnProxyProtocolTestSuite{})
307
 }
448
 }
449
+
450
+func TestIdleTracker(t *testing.T) {
451
+	t.Parallel()
452
+	suite.Run(t, &IdleTrackerTestSuite{})
453
+}
454
+
455
+func TestConnIdleTimeout(t *testing.T) {
456
+	t.Parallel()
457
+	suite.Run(t, &ConnIdleTimeoutTestSuite{})
458
+}

+ 6
- 2
mtglib/internal/dc/view.go Просмотреть файл

5
 }
5
 }
6
 
6
 
7
 func (d dcView) getV4(dc int) []Addr {
7
 func (d dcView) getV4(dc int) []Addr {
8
-	addrs := d.publicConfigs.getV4(dc)
8
+	var addrs []Addr
9
+
9
 	addrs = append(addrs, defaultDCAddrSet.getV4(dc)...)
10
 	addrs = append(addrs, defaultDCAddrSet.getV4(dc)...)
11
+	addrs = append(addrs, d.publicConfigs.getV4(dc)...)
10
 
12
 
11
 	return addrs
13
 	return addrs
12
 }
14
 }
13
 
15
 
14
 func (d dcView) getV6(dc int) []Addr {
16
 func (d dcView) getV6(dc int) []Addr {
15
-	addrs := d.publicConfigs.getV6(dc)
17
+	var addrs []Addr
18
+
16
 	addrs = append(addrs, defaultDCAddrSet.getV6(dc)...)
19
 	addrs = append(addrs, defaultDCAddrSet.getV6(dc)...)
20
+	addrs = append(addrs, d.publicConfigs.getV6(dc)...)
17
 
21
 
18
 	return addrs
22
 	return addrs
19
 }
23
 }

+ 12
- 6
mtglib/internal/doppel/scout.go Просмотреть файл

61
 		client.CloseIdleConnections()
61
 		client.CloseIdleConnections()
62
 	}
62
 	}
63
 
63
 
64
-	if err != nil || len(results.data) == 0 {
64
+	if err != nil {
65
 		return ScoutResult{}, err
65
 		return ScoutResult{}, err
66
 	}
66
 	}
67
 
67
 
68
+	data, writeIndex := results.Snapshot()
69
+
70
+	if len(data) == 0 {
71
+		return ScoutResult{}, nil
72
+	}
73
+
68
 	var result ScoutResult
74
 	var result ScoutResult
69
 
75
 
70
 	// Compute inter-record durations (existing logic).
76
 	// Compute inter-record durations (existing logic).
71
 	lastTimestamp := time.Time{}
77
 	lastTimestamp := time.Time{}
72
 
78
 
73
-	for i, v := range results.data {
79
+	for i, v := range data {
74
 		if v.recordType != tls.TypeApplicationData {
80
 		if v.recordType != tls.TypeApplicationData {
75
 			continue
81
 			continue
76
 		}
82
 		}
77
 
83
 
78
 		if lastTimestamp.IsZero() {
84
 		if lastTimestamp.IsZero() {
79
 			if i > 0 {
85
 			if i > 0 {
80
-				lastTimestamp = results.data[i-1].timestamp
86
+				lastTimestamp = data[i-1].timestamp
81
 			} else {
87
 			} else {
82
 				lastTimestamp = v.timestamp
88
 				lastTimestamp = v.timestamp
83
 			}
89
 			}
90
 	// Compute cert size: sum of ApplicationData payload between CCS and
96
 	// Compute cert size: sum of ApplicationData payload between CCS and
91
 	// the first client Write (which marks the end of server handshake).
97
 	// the first client Write (which marks the end of server handshake).
92
 	seenCCS := false
98
 	seenCCS := false
93
-	boundary := results.writeIndex
99
+	boundary := writeIndex
94
 	if boundary < 0 {
100
 	if boundary < 0 {
95
-		boundary = len(results.data)
101
+		boundary = len(data)
96
 	}
102
 	}
97
 
103
 
98
-	for i, v := range results.data {
104
+	for i, v := range data {
99
 		if i >= boundary {
105
 		if i >= boundary {
100
 			break
106
 			break
101
 		}
107
 		}

+ 20
- 1
mtglib/internal/doppel/scout_conn_collected.go Просмотреть файл

1
 package doppel
1
 package doppel
2
 
2
 
3
-import "time"
3
+import (
4
+	"slices"
5
+	"sync"
6
+	"time"
7
+)
4
 
8
 
5
 const (
9
 const (
6
 	ScoutConnCollectedPreallocSize = 100
10
 	ScoutConnCollectedPreallocSize = 100
13
 }
17
 }
14
 
18
 
15
 type ScoutConnCollected struct {
19
 type ScoutConnCollected struct {
20
+	mu         sync.Mutex
16
 	data       []ScoutConnResult
21
 	data       []ScoutConnResult
17
 	writeIndex int // index at which client first wrote post-handshake data; -1 if not set
22
 	writeIndex int // index at which client first wrote post-handshake data; -1 if not set
18
 }
23
 }
19
 
24
 
20
 func (s *ScoutConnCollected) Add(record byte, payloadLen int) {
25
 func (s *ScoutConnCollected) Add(record byte, payloadLen int) {
26
+	s.mu.Lock()
21
 	s.data = append(s.data, ScoutConnResult{
27
 	s.data = append(s.data, ScoutConnResult{
22
 		timestamp:  time.Now(),
28
 		timestamp:  time.Now(),
23
 		recordType: record,
29
 		recordType: record,
24
 		payloadLen: payloadLen,
30
 		payloadLen: payloadLen,
25
 	})
31
 	})
32
+	s.mu.Unlock()
26
 }
33
 }
27
 
34
 
28
 // MarkWrite records the current data length as the handshake boundary.
35
 // MarkWrite records the current data length as the handshake boundary.
29
 func (s *ScoutConnCollected) MarkWrite() {
36
 func (s *ScoutConnCollected) MarkWrite() {
37
+	s.mu.Lock()
30
 	if s.writeIndex < 0 {
38
 	if s.writeIndex < 0 {
31
 		s.writeIndex = len(s.data)
39
 		s.writeIndex = len(s.data)
32
 	}
40
 	}
41
+	s.mu.Unlock()
42
+}
43
+
44
+// Snapshot returns a copy of the collected data and the write index.
45
+func (s *ScoutConnCollected) Snapshot() ([]ScoutConnResult, int) {
46
+	s.mu.Lock()
47
+	snapshot := slices.Clone(s.data)
48
+	writeIndex := s.writeIndex
49
+	s.mu.Unlock()
50
+
51
+	return snapshot, writeIndex
33
 }
52
 }
34
 
53
 
35
 func NewScoutConnCollected() *ScoutConnCollected {
54
 func NewScoutConnCollected() *ScoutConnCollected {

+ 48
- 4
mtglib/internal/doppel/scout_conn_collected_test.go Просмотреть файл

1
 package doppel
1
 package doppel
2
 
2
 
3
 import (
3
 import (
4
+	"sync"
4
 	"testing"
5
 	"testing"
5
 	"time"
6
 	"time"
6
 
7
 
16
 	collected := NewScoutConnCollected()
17
 	collected := NewScoutConnCollected()
17
 	collected.Add(tls.TypeApplicationData, 100)
18
 	collected.Add(tls.TypeApplicationData, 100)
18
 
19
 
19
-	suite.Len(collected.data, 1)
20
-	suite.Equal(byte(tls.TypeApplicationData), collected.data[0].recordType)
20
+	data, _ := collected.Snapshot()
21
+
22
+	suite.Len(data, 1)
23
+	suite.Equal(byte(tls.TypeApplicationData), data[0].recordType)
21
 }
24
 }
22
 
25
 
23
 func (suite *ScoutConnCollectedTestSuite) TestAddTimestampsAreMonotonic() {
26
 func (suite *ScoutConnCollectedTestSuite) TestAddTimestampsAreMonotonic() {
31
 	time.Sleep(time.Microsecond)
34
 	time.Sleep(time.Microsecond)
32
 	collected.Add(tls.TypeApplicationData, 100)
35
 	collected.Add(tls.TypeApplicationData, 100)
33
 
36
 
34
-	for i := 1; i < len(collected.data); i++ {
35
-		suite.True(collected.data[i].timestamp.After(collected.data[i-1].timestamp))
37
+	data, _ := collected.Snapshot()
38
+
39
+	for i := 1; i < len(data); i++ {
40
+		suite.True(data[i].timestamp.After(data[i-1].timestamp))
36
 	}
41
 	}
37
 }
42
 }
38
 
43
 
44
+func (suite *ScoutConnCollectedTestSuite) TestConcurrentAddSnapshot() {
45
+	collected := NewScoutConnCollected()
46
+
47
+	var wg sync.WaitGroup
48
+
49
+	wg.Add(3)
50
+
51
+	go func() {
52
+		defer wg.Done()
53
+
54
+		for i := 0; i < 1000; i++ {
55
+			collected.Add(tls.TypeApplicationData, i)
56
+		}
57
+	}()
58
+
59
+	go func() {
60
+		defer wg.Done()
61
+
62
+		for i := 0; i < 100; i++ {
63
+			collected.MarkWrite()
64
+		}
65
+	}()
66
+
67
+	go func() {
68
+		defer wg.Done()
69
+
70
+		for i := 0; i < 1000; i++ {
71
+			// call Snapshot concurrently to exercise the lock under -race
72
+			collected.Snapshot() //nolint:errcheck
73
+		}
74
+	}()
75
+
76
+	wg.Wait()
77
+
78
+	data, writeIndex := collected.Snapshot()
79
+	suite.Len(data, 1000)
80
+	suite.GreaterOrEqual(writeIndex, 0)
81
+}
82
+
39
 func TestScoutConnCollected(t *testing.T) {
83
 func TestScoutConnCollected(t *testing.T) {
40
 	t.Parallel()
84
 	t.Parallel()
41
 	suite.Run(t, &ScoutConnCollectedTestSuite{})
85
 	suite.Run(t, &ScoutConnCollectedTestSuite{})

+ 122
- 16
mtglib/internal/tls/fake/client_side.go Просмотреть файл

6
 	"crypto/sha256"
6
 	"crypto/sha256"
7
 	"crypto/subtle"
7
 	"crypto/subtle"
8
 	"encoding/binary"
8
 	"encoding/binary"
9
+	"errors"
9
 	"fmt"
10
 	"fmt"
10
 	"io"
11
 	"io"
11
 	"net"
12
 	"net"
23
 	RandomOffset = 1 + 2 + 2 + 1 + 3 + 2
24
 	RandomOffset = 1 + 2 + 2 + 1 + 3 + 2
24
 
25
 
25
 	sniDNSNamesListType = 0
26
 	sniDNSNamesListType = 0
27
+
28
+	// maxContinuationRecords limits the number of continuation TLS records
29
+	// that reassembleTLSHandshake will read. This prevents resource exhaustion
30
+	// from adversarial fragmentation.
31
+	maxContinuationRecords = 10
26
 )
32
 )
27
 
33
 
28
 var (
34
 var (
71
 	}
77
 	}
72
 	defer conn.SetReadDeadline(resetDeadline) //nolint: errcheck
78
 	defer conn.SetReadDeadline(resetDeadline) //nolint: errcheck
73
 
79
 
80
+	reassembled, err := reassembleTLSHandshake(conn)
81
+	if err != nil {
82
+		return nil, fmt.Errorf("cannot reassemble TLS records: %w", err)
83
+	}
84
+
74
 	handshakeCopyBuf := &bytes.Buffer{}
85
 	handshakeCopyBuf := &bytes.Buffer{}
75
-	reader := io.TeeReader(conn, handshakeCopyBuf)
86
+	reader := io.TeeReader(reassembled, handshakeCopyBuf)
76
 
87
 
77
-	reader, err := parseTLSHeader(reader)
78
-	if err != nil {
79
-		return nil, fmt.Errorf("cannot parse tls header: %w", err)
88
+	// Skip the TLS record header (validated during reassembly).
89
+	// The header still flows through TeeReader into handshakeCopyBuf for HMAC.
90
+	if _, err = io.CopyN(io.Discard, reader, tls.SizeHeader); err != nil {
91
+		return nil, fmt.Errorf("cannot skip tls header: %w", err)
80
 	}
92
 	}
81
 
93
 
82
 	reader, err = parseHandshakeHeader(reader)
94
 	reader, err = parseHandshakeHeader(reader)
135
 	return nil, ErrBadDigest
147
 	return nil, ErrBadDigest
136
 }
148
 }
137
 
149
 
138
-func parseTLSHeader(r io.Reader) (io.Reader, error) {
139
-	// record_type(1) + version(2) + size(2)
140
-	//   16 - type is 0x16 (handshake record)
141
-	//   03 01 - protocol version is "3,1" (also known as TLS 1.0)
142
-	//   00 f8 - 0xF8 (248) bytes of handshake message follows
143
-	header := [1 + 2 + 2]byte{}
144
-
145
-	if _, err := io.ReadFull(r, header[:]); err != nil {
150
+// reassembleTLSHandshake reads one or more TLS records from conn,
151
+// validates the record type and version, and reassembles fragmented
152
+// handshake payloads into a single TLS record.
153
+//
154
+// Per RFC 5246 Section 6.2.1, handshake messages may be fragmented
155
+// across multiple TLS records. DPI bypass tools like ByeDPI use this
156
+// to evade censorship.
157
+//
158
+// The returned buffer contains the full TLS record (header + payload)
159
+// so that callers can include the header in HMAC computation.
160
+func reassembleTLSHandshake(conn io.Reader) (*bytes.Buffer, error) {
161
+	header := [tls.SizeHeader]byte{}
162
+
163
+	if _, err := io.ReadFull(conn, header[:]); err != nil {
146
 		return nil, fmt.Errorf("cannot read record header: %w", err)
164
 		return nil, fmt.Errorf("cannot read record header: %w", err)
147
 	}
165
 	}
148
 
166
 
167
+	length := int64(binary.BigEndian.Uint16(header[3:]))
168
+	payload := &bytes.Buffer{}
169
+
170
+	if _, err := io.CopyN(payload, conn, length); err != nil {
171
+		return nil, fmt.Errorf("cannot read record payload: %w", err)
172
+	}
173
+
149
 	if header[0] != tls.TypeHandshake {
174
 	if header[0] != tls.TypeHandshake {
150
 		return nil, fmt.Errorf("unexpected record type %#x", header[0])
175
 		return nil, fmt.Errorf("unexpected record type %#x", header[0])
151
 	}
176
 	}
154
 		return nil, fmt.Errorf("unexpected protocol version %#x %#x", header[1], header[2])
179
 		return nil, fmt.Errorf("unexpected protocol version %#x %#x", header[1], header[2])
155
 	}
180
 	}
156
 
181
 
157
-	length := int64(binary.BigEndian.Uint16(header[3:]))
158
-	buf := &bytes.Buffer{}
182
+	// Reassemble fragmented payload. continuationCount caps the total
183
+	// number of continuation records across both phases below.
184
+	continuationCount := 0
159
 
185
 
160
-	_, err := io.CopyN(buf, r, length)
186
+	// Phase 1: read continuation records until we have at least the
187
+	// 4-byte handshake header (type + uint24 length) to determine the
188
+	// expected total size.
189
+	for ; payload.Len() < 4 && continuationCount < maxContinuationRecords; continuationCount++ {
190
+		prevLen := payload.Len()
161
 
191
 
162
-	return buf, err
192
+		if err := readContinuationRecord(conn, payload); err != nil {
193
+			payload.Truncate(prevLen) // discard partial data on error
194
+
195
+			if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) {
196
+				break // no more records — let downstream parsing handle what we have
197
+			}
198
+
199
+			return nil, err
200
+		}
201
+	}
202
+
203
+	// Phase 2: we know the expected handshake size — read remaining
204
+	// continuation records until the payload is complete.
205
+	if payload.Len() >= 4 {
206
+		p := payload.Bytes()
207
+		expectedTotal := 4 + (int(p[1])<<16 | int(p[2])<<8 | int(p[3]))
208
+
209
+		if expectedTotal > 0xFFFF {
210
+			return nil, fmt.Errorf("handshake message too large: %d bytes", expectedTotal)
211
+		}
212
+
213
+		for ; payload.Len() < expectedTotal && continuationCount < maxContinuationRecords; continuationCount++ {
214
+			if err := readContinuationRecord(conn, payload); err != nil {
215
+				return nil, err
216
+			}
217
+		}
218
+
219
+		if payload.Len() < expectedTotal {
220
+			return nil, fmt.Errorf("cannot reassemble handshake: too many continuation records")
221
+		}
222
+
223
+		payload.Truncate(expectedTotal)
224
+	}
225
+
226
+	if payload.Len() > 0xFFFF {
227
+		return nil, fmt.Errorf("reassembled payload too large: %d bytes", payload.Len())
228
+	}
229
+
230
+	// Reconstruct a single TLS record with the reassembled payload.
231
+	result := &bytes.Buffer{}
232
+	result.Grow(tls.SizeHeader + payload.Len())
233
+	result.Write(header[:3])
234
+	binary.Write(result, binary.BigEndian, uint16(payload.Len())) //nolint:errcheck // bytes.Buffer.Write never fails
235
+	result.Write(payload.Bytes())
236
+
237
+	return result, nil
238
+}
239
+
240
+// readContinuationRecord reads the next TLS record header and appends its
241
+// full payload to dst. It returns an error if the record is not a handshake
242
+// record.
243
+func readContinuationRecord(conn io.Reader, dst *bytes.Buffer) error {
244
+	nextHeader := [tls.SizeHeader]byte{}
245
+
246
+	if _, err := io.ReadFull(conn, nextHeader[:]); err != nil {
247
+		return fmt.Errorf("cannot read continuation record header: %w", err)
248
+	}
249
+
250
+	if nextHeader[0] != tls.TypeHandshake {
251
+		return fmt.Errorf("unexpected continuation record type %#x", nextHeader[0])
252
+	}
253
+
254
+	if nextHeader[1] != 3 || nextHeader[2] != 1 {
255
+		return fmt.Errorf("unexpected continuation record version %#x %#x", nextHeader[1], nextHeader[2])
256
+	}
257
+
258
+	nextLength := int64(binary.BigEndian.Uint16(nextHeader[3:]))
259
+
260
+	if nextLength == 0 {
261
+		return fmt.Errorf("zero-length continuation record")
262
+	}
263
+
264
+	if _, err := io.CopyN(dst, conn, nextLength); err != nil {
265
+		return fmt.Errorf("cannot read continuation record payload: %w", err)
266
+	}
267
+
268
+	return nil
163
 }
269
 }
164
 
270
 
165
 func parseHandshakeHeader(r io.Reader) (io.Reader, error) {
271
 func parseHandshakeHeader(r io.Reader) (io.Reader, error) {

+ 272
- 0
mtglib/internal/tls/fake/client_side_test.go Просмотреть файл

530
 	t.Parallel()
530
 	t.Parallel()
531
 	suite.Run(t, &ReadClientHelloMultiTestSuite{})
531
 	suite.Run(t, &ReadClientHelloMultiTestSuite{})
532
 }
532
 }
533
+
534
+// --- Fragmented TLS record tests ---
535
+
536
+// fragmentTLSRecord splits a single TLS record into n TLS records by
537
+// dividing the payload into roughly equal parts. Each part gets its own
538
+// TLS record header with the same record type and version.
539
+func fragmentTLSRecord(t testing.TB, full []byte, n int) []byte {
540
+	t.Helper()
541
+
542
+	recordType := full[0]
543
+	version := full[1:3]
544
+	payload := full[tls.SizeHeader:]
545
+
546
+	chunkSize := len(payload) / n
547
+	result := &bytes.Buffer{}
548
+
549
+	for i := 0; i < n; i++ {
550
+		start := i * chunkSize
551
+		end := start + chunkSize
552
+
553
+		if i == n-1 {
554
+			end = len(payload)
555
+		}
556
+
557
+		chunk := payload[start:end]
558
+		result.WriteByte(recordType)
559
+		result.Write(version)
560
+		require.NoError(t, binary.Write(result, binary.BigEndian, uint16(len(chunk))))
561
+		result.Write(chunk)
562
+	}
563
+
564
+	return result.Bytes()
565
+}
566
+
567
+// splitPayloadAt creates two TLS records from a single record by splitting
568
+// the payload at the given byte position.
569
+func splitPayloadAt(t testing.TB, full []byte, pos int) []byte {
570
+	t.Helper()
571
+
572
+	payload := full[tls.SizeHeader:]
573
+	buf := &bytes.Buffer{}
574
+
575
+	buf.WriteByte(tls.TypeHandshake)
576
+	buf.Write(full[1:3])
577
+	require.NoError(t, binary.Write(buf, binary.BigEndian, uint16(pos)))
578
+	buf.Write(payload[:pos])
579
+
580
+	buf.WriteByte(tls.TypeHandshake)
581
+	buf.Write(full[1:3])
582
+	require.NoError(t, binary.Write(buf, binary.BigEndian, uint16(len(payload)-pos)))
583
+	buf.Write(payload[pos:])
584
+
585
+	return buf.Bytes()
586
+}
587
+
588
+type ParseClientHelloFragmentedTestSuite struct {
589
+	suite.Suite
590
+
591
+	secret   mtglib.Secret
592
+	snapshot *clientHelloSnapshot
593
+}
594
+
595
+func (s *ParseClientHelloFragmentedTestSuite) SetupSuite() {
596
+	parsed, err := mtglib.ParseSecret(
597
+		"ee367a189aee18fa31c190054efd4a8e9573746f726167652e676f6f676c65617069732e636f6d",
598
+	)
599
+	require.NoError(s.T(), err)
600
+
601
+	s.secret = parsed
602
+
603
+	fileData, err := os.ReadFile("testdata/client-hello-ok-19dfe38384b9884b.json")
604
+	require.NoError(s.T(), err)
605
+
606
+	s.snapshot = &clientHelloSnapshot{}
607
+	require.NoError(s.T(), json.Unmarshal(fileData, s.snapshot))
608
+}
609
+
610
+func (s *ParseClientHelloFragmentedTestSuite) makeConn(data []byte) *parseClientHelloConnMock {
611
+	readBuf := &bytes.Buffer{}
612
+	readBuf.Write(data)
613
+
614
+	connMock := &parseClientHelloConnMock{
615
+		readBuf: readBuf,
616
+	}
617
+
618
+	connMock.
619
+		On("SetReadDeadline", mock.AnythingOfType("time.Time")).
620
+		Twice().
621
+		Return(nil)
622
+
623
+	return connMock
624
+}
625
+
626
+func (s *ParseClientHelloFragmentedTestSuite) TestReassemblySuccess() {
627
+	full := s.snapshot.GetFull()
628
+
629
+	tests := []struct {
630
+		name string
631
+		data []byte
632
+	}{
633
+		{"two equal fragments", fragmentTLSRecord(s.T(), full, 2)},
634
+		{"three equal fragments", fragmentTLSRecord(s.T(), full, 3)},
635
+		{"single byte first fragment", splitPayloadAt(s.T(), full, 1)},
636
+		{"three byte first fragment", splitPayloadAt(s.T(), full, 3)},
637
+	}
638
+
639
+	for _, tt := range tests {
640
+		s.Run(tt.name, func() {
641
+			connMock := s.makeConn(tt.data)
642
+			defer connMock.AssertExpectations(s.T())
643
+
644
+			hello, err := fake.ReadClientHello(
645
+				connMock,
646
+				s.secret.Key[:],
647
+				s.secret.Host,
648
+				TolerateTime,
649
+			)
650
+			s.Require().NoError(err)
651
+
652
+			s.Equal(s.snapshot.GetRandom(), hello.Random[:])
653
+			s.Equal(s.snapshot.GetSessionID(), hello.SessionID)
654
+			s.Equal(uint16(s.snapshot.CipherSuite), hello.CipherSuite)
655
+		})
656
+	}
657
+}
658
+
659
+func (s *ParseClientHelloFragmentedTestSuite) TestReassemblyErrors() {
660
+	full := s.snapshot.GetFull()
661
+	payload := full[tls.SizeHeader:]
662
+
663
+	tests := []struct {
664
+		name      string
665
+		buildData func() []byte
666
+		errMsg    string
667
+	}{
668
+		{
669
+			name: "wrong continuation record type",
670
+			buildData: func() []byte {
671
+				buf := &bytes.Buffer{}
672
+				buf.WriteByte(tls.TypeHandshake)
673
+				buf.Write(full[1:3])
674
+				require.NoError(s.T(), binary.Write(buf, binary.BigEndian, uint16(10)))
675
+				buf.Write(payload[:10])
676
+				// Wrong type: application data instead of handshake
677
+				buf.WriteByte(tls.TypeApplicationData)
678
+				buf.Write(full[1:3])
679
+				require.NoError(s.T(), binary.Write(buf, binary.BigEndian, uint16(len(payload)-10)))
680
+				buf.Write(payload[10:])
681
+				return buf.Bytes()
682
+			},
683
+			errMsg: "unexpected continuation record type",
684
+		},
685
+		{
686
+			name: "too many continuation records",
687
+			buildData: func() []byte {
688
+				// Handshake header claiming 256 bytes, but we only send 1 byte per continuation
689
+				handshakePayload := []byte{0x01, 0x00, 0x01, 0x00}
690
+				buf := &bytes.Buffer{}
691
+				buf.WriteByte(tls.TypeHandshake)
692
+				buf.Write([]byte{3, 1})
693
+				require.NoError(s.T(), binary.Write(buf, binary.BigEndian, uint16(len(handshakePayload))))
694
+				buf.Write(handshakePayload)
695
+				for range 11 {
696
+					buf.WriteByte(tls.TypeHandshake)
697
+					buf.Write([]byte{3, 1})
698
+					require.NoError(s.T(), binary.Write(buf, binary.BigEndian, uint16(1)))
699
+					buf.WriteByte(0xAB)
700
+				}
701
+				return buf.Bytes()
702
+			},
703
+			errMsg: "too many continuation records",
704
+		},
705
+		{
706
+			name: "zero-length continuation record",
707
+			buildData: func() []byte {
708
+				buf := &bytes.Buffer{}
709
+				buf.WriteByte(tls.TypeHandshake)
710
+				buf.Write(full[1:3])
711
+				require.NoError(s.T(), binary.Write(buf, binary.BigEndian, uint16(10)))
712
+				buf.Write(payload[:10])
713
+				// Valid header but zero-length payload
714
+				buf.WriteByte(tls.TypeHandshake)
715
+				buf.Write(full[1:3])
716
+				require.NoError(s.T(), binary.Write(buf, binary.BigEndian, uint16(0)))
717
+				return buf.Bytes()
718
+			},
719
+			errMsg: "zero-length continuation record",
720
+		},
721
+		{
722
+			name: "wrong continuation record version",
723
+			buildData: func() []byte {
724
+				buf := &bytes.Buffer{}
725
+				buf.WriteByte(tls.TypeHandshake)
726
+				buf.Write(full[1:3])
727
+				require.NoError(s.T(), binary.Write(buf, binary.BigEndian, uint16(10)))
728
+				buf.Write(payload[:10])
729
+				// Wrong version: 3.3 instead of 3.1
730
+				buf.WriteByte(tls.TypeHandshake)
731
+				buf.Write([]byte{3, 3})
732
+				require.NoError(s.T(), binary.Write(buf, binary.BigEndian, uint16(len(payload)-10)))
733
+				buf.Write(payload[10:])
734
+				return buf.Bytes()
735
+			},
736
+			errMsg: "unexpected continuation record version",
737
+		},
738
+		{
739
+			name: "handshake message too large",
740
+			buildData: func() []byte {
741
+				// Handshake header claiming 0x010000 (65536) bytes — exceeds 0xFFFF limit
742
+				handshakePayload := []byte{0x01, 0x01, 0x00, 0x00}
743
+				buf := &bytes.Buffer{}
744
+				buf.WriteByte(tls.TypeHandshake)
745
+				buf.Write([]byte{3, 1})
746
+				require.NoError(s.T(), binary.Write(buf, binary.BigEndian, uint16(len(handshakePayload))))
747
+				buf.Write(handshakePayload)
748
+				return buf.Bytes()
749
+			},
750
+			errMsg: "handshake message too large",
751
+		},
752
+		{
753
+			name: "truncated continuation record header",
754
+			buildData: func() []byte {
755
+				buf := &bytes.Buffer{}
756
+				buf.WriteByte(tls.TypeHandshake)
757
+				buf.Write(full[1:3])
758
+				require.NoError(s.T(), binary.Write(buf, binary.BigEndian, uint16(10)))
759
+				buf.Write(payload[:10])
760
+				// Connection ends mid-header (only 2 bytes)
761
+				buf.WriteByte(tls.TypeHandshake)
762
+				buf.WriteByte(3)
763
+				return buf.Bytes()
764
+			},
765
+			errMsg: "cannot read continuation record header",
766
+		},
767
+		{
768
+			name: "truncated continuation record payload",
769
+			buildData: func() []byte {
770
+				buf := &bytes.Buffer{}
771
+				buf.WriteByte(tls.TypeHandshake)
772
+				buf.Write(full[1:3])
773
+				require.NoError(s.T(), binary.Write(buf, binary.BigEndian, uint16(10)))
774
+				buf.Write(payload[:10])
775
+				// Claims 100 bytes but no payload follows
776
+				buf.WriteByte(tls.TypeHandshake)
777
+				buf.Write(full[1:3])
778
+				require.NoError(s.T(), binary.Write(buf, binary.BigEndian, uint16(100)))
779
+				return buf.Bytes()
780
+			},
781
+			errMsg: "cannot read continuation record payload",
782
+		},
783
+	}
784
+
785
+	for _, tt := range tests {
786
+		s.Run(tt.name, func() {
787
+			connMock := s.makeConn(tt.buildData())
788
+			defer connMock.AssertExpectations(s.T())
789
+
790
+			_, err := fake.ReadClientHello(
791
+				connMock,
792
+				s.secret.Key[:],
793
+				s.secret.Host,
794
+				TolerateTime,
795
+			)
796
+			s.ErrorContains(err, tt.errMsg)
797
+		})
798
+	}
799
+}
800
+
801
+func TestParseClientHelloFragmented(t *testing.T) {
802
+	t.Parallel()
803
+	suite.Run(t, &ParseClientHelloFragmentedTestSuite{})
804
+}

+ 8
- 4
mtglib/proxy.go Просмотреть файл

111
 		return
111
 		return
112
 	}
112
 	}
113
 
113
 
114
+	tracker := newIdleTracker(p.idleTimeout)
115
+
114
 	relay.Relay(
116
 	relay.Relay(
115
 		ctx,
117
 		ctx,
116
 		ctx.logger.Named("relay"),
118
 		ctx.logger.Named("relay"),
117
-		connIdleTimeout{Conn: ctx.telegramConn, timeout: p.idleTimeout},
118
-		newCountingConn(connIdleTimeout{Conn: ctx.clientConn, timeout: p.idleTimeout}, p.stats, ctx.secretName),
119
+		connIdleTimeout{Conn: ctx.telegramConn, tracker: tracker},
120
+		newCountingConn(connIdleTimeout{Conn: ctx.clientConn, tracker: tracker}, p.stats, ctx.secretName),
119
 	)
121
 	)
120
 }
122
 }
121
 
123
 
330
 		stream:   p.eventStream,
332
 		stream:   p.eventStream,
331
 	}
333
 	}
332
 
334
 
335
+	tracker := newIdleTracker(p.idleTimeout)
336
+
333
 	relay.Relay(
337
 	relay.Relay(
334
 		ctx,
338
 		ctx,
335
 		ctx.logger.Named("domain-fronting"),
339
 		ctx.logger.Named("domain-fronting"),
336
-		connIdleTimeout{Conn: frontConn, timeout: p.idleTimeout},
337
-		connIdleTimeout{Conn: conn, timeout: p.idleTimeout},
340
+		connIdleTimeout{Conn: frontConn, tracker: tracker},
341
+		connIdleTimeout{Conn: conn, tracker: tracker},
338
 	)
342
 	)
339
 }
343
 }
340
 
344
 

+ 1
- 1
mtglib/proxy_test.go Просмотреть файл

175
 	addr := fmt.Sprintf("https://%s/headers", suite.ProxyAddress())
175
 	addr := fmt.Sprintf("https://%s/headers", suite.ProxyAddress())
176
 
176
 
177
 	resp, err := client.Get(addr) //nolint: noctx
177
 	resp, err := client.Get(addr) //nolint: noctx
178
-	suite.NoError(err)
178
+	suite.Require().NoError(err)
179
 
179
 
180
 	defer resp.Body.Close() //nolint: errcheck
180
 	defer resp.Body.Close() //nolint: errcheck
181
 
181
 

Загрузка…
Отмена
Сохранить