Просмотр исходного кода

Merge remote-tracking branch 'upstream/master' into fix/toml-section-order

# Conflicts:
#	.github/workflows/ci.yaml
#	README.md
#	default.pgo
#	mtglib/internal/tls/fake/client_side.go
#	mtglib/internal/tls/fake/client_side_test.go
#	mtglib/proxy.go
#	mtglib/proxy_opts.go
pull/434/head
Alexey Dolotov 1 месяц назад
Родитель
Commit
aec151d158

+ 1
- 1
Dockerfile Просмотреть файл

@@ -33,7 +33,7 @@ RUN go mod download
33 33
 COPY . /app
34 34
 
35 35
 RUN set -x \
36
-  && version="$(git describe --exact-match HEAD 2>/dev/null || git describe --tags --always)" \
36
+  && version="$(git describe --exact-match HEAD 2>/dev/null || git describe --tags --always 2>/dev/null || echo dev)" \
37 37
   && go build \
38 38
       -trimpath \
39 39
       -mod=readonly \

+ 1
- 1
antireplay/stable_bloom_filter_test.go Просмотреть файл

@@ -12,7 +12,7 @@ type StableBloomFilterTestSuite struct {
12 12
 }
13 13
 
14 14
 func (suite *StableBloomFilterTestSuite) TestOp() {
15
-	filter := antireplay.NewStableBloomFilter(500, 0.001)
15
+	filter := antireplay.NewStableBloomFilter(100000, 0.001)
16 16
 
17 17
 	suite.False(filter.SeenBefore([]byte{1, 2, 3}))
18 18
 	suite.False(filter.SeenBefore([]byte{4, 5, 6}))

+ 4
- 4
internal/config/config.go Просмотреть файл

@@ -53,10 +53,10 @@ type Config struct {
53 53
 		Blocklist    ListConfig `json:"blocklist"`
54 54
 		Allowlist    ListConfig `json:"allowlist"`
55 55
 		Doppelganger struct {
56
-			URLs            []TypeHttpsURL  `json:"urls"`
57
-			Repeats         TypeConcurrency `json:"repeats_per_raid"`
58
-			UpdateEach      TypeDuration    `json:"raid_each"`
59
-			DRS             TypeBool        `json:"drs"`
56
+			URLs       []TypeHttpsURL  `json:"urls"`
57
+			Repeats    TypeConcurrency `json:"repeats_per_raid"`
58
+			UpdateEach TypeDuration    `json:"raid_each"`
59
+			DRS        TypeBool        `json:"drs"`
60 60
 		} `json:"doppelganger"`
61 61
 	} `json:"defense"`
62 62
 	Network struct {

+ 4
- 4
internal/config/parse.go Просмотреть файл

@@ -48,10 +48,10 @@ type tomlConfig struct {
48 48
 			UpdateEach          string   `toml:"update-each" json:"updateEach,omitempty"`
49 49
 		} `toml:"allowlist" json:"allowlist,omitempty"`
50 50
 		Doppelganger struct {
51
-			URLs            []string `toml:"urls" json:"urls,omitempty"`
52
-			Repeats         uint     `toml:"repeats-per-raid" json:"repeats_per_raid,omitempty"`
53
-			UpdateEach      string   `toml:"raid-each" json:"raid_each,omitempty"`
54
-			DRS             bool     `toml:"drs" json:"drs,omitempty"`
51
+			URLs       []string `toml:"urls" json:"urls,omitempty"`
52
+			Repeats    uint     `toml:"repeats-per-raid" json:"repeats_per_raid,omitempty"`
53
+			UpdateEach string   `toml:"raid-each" json:"raid_each,omitempty"`
54
+			DRS        bool     `toml:"drs" json:"drs,omitempty"`
55 55
 		} `toml:"doppelganger" json:"doppelganger,omitempty"`
56 56
 	} `toml:"defense" json:"defense,omitempty"`
57 57
 	Network struct {

+ 50
- 5
mtglib/conns.go Просмотреть файл

@@ -3,9 +3,11 @@ package mtglib
3 3
 import (
4 4
 	"bytes"
5 5
 	"context"
6
+	"errors"
6 7
 	"fmt"
7 8
 	"io"
8 9
 	"net"
10
+	"sync/atomic"
9 11
 	"time"
10 12
 
11 13
 	"github.com/9seconds/mtg/v2/essentials"
@@ -97,20 +99,63 @@ func newConnProxyProtocol(source, target essentials.Conn) *connProxyProtocol {
97 99
 	}
98 100
 }
99 101
 
102
+// idleTracker is a shared idle tracker for a pair of relay connections.
103
+// Both directions update the same timestamp so that activity in one direction
104
+// prevents the other (idle) direction from timing out.
105
+type idleTracker struct {
106
+	lastActive atomic.Pointer[time.Time]
107
+	timeout    time.Duration
108
+}
109
+
110
+func newIdleTracker(timeout time.Duration) *idleTracker {
111
+	t := &idleTracker{timeout: timeout}
112
+	t.touch()
113
+
114
+	return t
115
+}
116
+
117
+func (t *idleTracker) touch() {
118
+	stamp := time.Now()
119
+	t.lastActive.Store(&stamp)
120
+}
121
+
122
+func (t *idleTracker) isIdle() bool {
123
+	return time.Since(*t.lastActive.Load()) >= t.timeout
124
+}
125
+
100 126
 type connIdleTimeout struct {
101 127
 	essentials.Conn
102 128
 
103
-	timeout time.Duration
129
+	tracker *idleTracker
104 130
 }
105 131
 
106 132
 func (c connIdleTimeout) Read(b []byte) (int, error) {
107
-	c.SetReadDeadline(time.Now().Add(c.timeout)) //nolint: errcheck
133
+	var netErr net.Error
134
+
135
+	for {
136
+		c.SetReadDeadline(time.Now().Add(c.tracker.timeout)) //nolint: errcheck
108 137
 
109
-	return c.Conn.Read(b) //nolint: wrapcheck
138
+		n, err := c.Conn.Read(b)
139
+
140
+		switch {
141
+		case err == nil:
142
+			c.tracker.touch()
143
+			return n, nil
144
+		case errors.As(err, &netErr) && netErr.Timeout() && !c.tracker.isIdle():
145
+			continue
146
+		}
147
+
148
+		return n, err
149
+	}
110 150
 }
111 151
 
112 152
 func (c connIdleTimeout) Write(b []byte) (int, error) {
113
-	c.SetWriteDeadline(time.Now().Add(c.timeout)) //nolint: errcheck
153
+	c.SetWriteDeadline(time.Now().Add(c.tracker.timeout)) //nolint: errcheck
114 154
 
115
-	return c.Conn.Write(b) //nolint: wrapcheck
155
+	n, err := c.Conn.Write(b)
156
+	if n > 0 {
157
+		c.tracker.touch()
158
+	}
159
+
160
+	return n, err //nolint: wrapcheck
116 161
 }

+ 151
- 0
mtglib/conns_internal_test.go Просмотреть файл

@@ -16,6 +16,12 @@ import (
16 16
 	"github.com/stretchr/testify/suite"
17 17
 )
18 18
 
19
+type netTimeoutError struct{}
20
+
21
+func (e netTimeoutError) Error() string   { return "i/o timeout" }
22
+func (e netTimeoutError) Timeout() bool   { return true }
23
+func (e netTimeoutError) Temporary() bool { return true }
24
+
19 25
 type ConnRewindBaseConn struct {
20 26
 	testlib.EssentialsConnMock
21 27
 
@@ -291,6 +297,141 @@ func (suite *ConnProxyProtocolTestSuite) TearDownTest() {
291 297
 	suite.targetConnMock.AssertExpectations(suite.T())
292 298
 }
293 299
 
300
+type IdleTrackerTestSuite struct {
301
+	suite.Suite
302
+}
303
+
304
+func (suite *IdleTrackerTestSuite) TestNewNotIdle() {
305
+	tracker := newIdleTracker(time.Second)
306
+	suite.False(tracker.isIdle())
307
+}
308
+
309
+func (suite *IdleTrackerTestSuite) TestIdleAfterTimeout() {
310
+	tracker := newIdleTracker(10 * time.Millisecond)
311
+	time.Sleep(20 * time.Millisecond)
312
+
313
+	suite.True(tracker.isIdle())
314
+}
315
+
316
+func (suite *IdleTrackerTestSuite) TestTouchResetsIdle() {
317
+	tracker := newIdleTracker(50 * time.Millisecond)
318
+	time.Sleep(30 * time.Millisecond)
319
+
320
+	tracker.touch()
321
+
322
+	suite.False(tracker.isIdle())
323
+}
324
+
325
+type ConnIdleTimeoutTestSuite struct {
326
+	suite.Suite
327
+
328
+	connMock *testlib.EssentialsConnMock
329
+	tracker  *idleTracker
330
+	conn     connIdleTimeout
331
+}
332
+
333
+func (suite *ConnIdleTimeoutTestSuite) SetupTest() {
334
+	suite.connMock = &testlib.EssentialsConnMock{}
335
+	suite.tracker = newIdleTracker(time.Second)
336
+	suite.conn = connIdleTimeout{
337
+		Conn:    suite.connMock,
338
+		tracker: suite.tracker,
339
+	}
340
+}
341
+
342
+func (suite *ConnIdleTimeoutTestSuite) TearDownTest() {
343
+	suite.connMock.AssertExpectations(suite.T())
344
+}
345
+
346
+func (suite *ConnIdleTimeoutTestSuite) TestReadOk() {
347
+	suite.connMock.On("SetReadDeadline", mock.Anything).Return(nil)
348
+	suite.connMock.On("Read", mock.Anything).Once().Return(5, nil)
349
+
350
+	n, err := suite.conn.Read(make([]byte, 10))
351
+	suite.NoError(err)
352
+	suite.Equal(5, n)
353
+}
354
+
355
+func (suite *ConnIdleTimeoutTestSuite) TestReadNonTimeoutErr() {
356
+	suite.connMock.On("SetReadDeadline", mock.Anything).Return(nil)
357
+	suite.connMock.On("Read", mock.Anything).Once().Return(0, io.EOF)
358
+
359
+	n, err := suite.conn.Read(make([]byte, 10))
360
+	suite.True(errors.Is(err, io.EOF))
361
+	suite.Equal(0, n)
362
+}
363
+
364
+func (suite *ConnIdleTimeoutTestSuite) TestReadTimeoutRetriesWhenNotIdle() {
365
+	suite.connMock.On("SetReadDeadline", mock.Anything).Return(nil)
366
+	suite.connMock.On("Read", mock.Anything).Once().Return(0, netTimeoutError{})
367
+	suite.connMock.On("Read", mock.Anything).Once().Return(5, nil)
368
+
369
+	n, err := suite.conn.Read(make([]byte, 10))
370
+	suite.NoError(err)
371
+	suite.Equal(5, n)
372
+}
373
+
374
+func (suite *ConnIdleTimeoutTestSuite) TestReadTimeoutClosesWhenIdle() {
375
+	suite.tracker = newIdleTracker(time.Millisecond)
376
+	suite.conn = connIdleTimeout{
377
+		Conn:    suite.connMock,
378
+		tracker: suite.tracker,
379
+	}
380
+
381
+	time.Sleep(5 * time.Millisecond)
382
+
383
+	suite.connMock.On("SetReadDeadline", mock.Anything).Return(nil)
384
+	suite.connMock.On("Read", mock.Anything).Once().Return(0, netTimeoutError{})
385
+
386
+	n, err := suite.conn.Read(make([]byte, 10))
387
+	suite.Equal(0, n)
388
+
389
+	netErr, ok := err.(net.Error) //nolint: errorlint
390
+	suite.True(ok)
391
+	suite.True(netErr.Timeout())
392
+}
393
+
394
+func (suite *ConnIdleTimeoutTestSuite) TestSharedTrackerPreventsFalseTimeout() {
395
+	connMock2 := &testlib.EssentialsConnMock{}
396
+	conn2 := connIdleTimeout{
397
+		Conn:    connMock2,
398
+		tracker: suite.tracker,
399
+	}
400
+
401
+	connMock2.On("SetWriteDeadline", mock.Anything).Return(nil)
402
+	connMock2.On("Write", mock.Anything).Once().Return(5, nil)
403
+
404
+	_, _ = conn2.Write(make([]byte, 5))
405
+
406
+	suite.connMock.On("SetReadDeadline", mock.Anything).Return(nil)
407
+	suite.connMock.On("Read", mock.Anything).Once().Return(0, netTimeoutError{})
408
+	suite.connMock.On("Read", mock.Anything).Once().Return(3, nil)
409
+
410
+	n, err := suite.conn.Read(make([]byte, 10))
411
+	suite.NoError(err)
412
+	suite.Equal(3, n)
413
+
414
+	connMock2.AssertExpectations(suite.T())
415
+}
416
+
417
+func (suite *ConnIdleTimeoutTestSuite) TestWriteOk() {
418
+	suite.connMock.On("SetWriteDeadline", mock.Anything).Return(nil)
419
+	suite.connMock.On("Write", mock.Anything).Once().Return(5, nil)
420
+
421
+	n, err := suite.conn.Write(make([]byte, 5))
422
+	suite.NoError(err)
423
+	suite.Equal(5, n)
424
+}
425
+
426
+func (suite *ConnIdleTimeoutTestSuite) TestWriteErr() {
427
+	suite.connMock.On("SetWriteDeadline", mock.Anything).Return(nil)
428
+	suite.connMock.On("Write", mock.Anything).Once().Return(0, io.EOF)
429
+
430
+	n, err := suite.conn.Write(make([]byte, 5))
431
+	suite.True(errors.Is(err, io.EOF))
432
+	suite.Equal(0, n)
433
+}
434
+
294 435
 func TestConnTraffic(t *testing.T) {
295 436
 	t.Parallel()
296 437
 	suite.Run(t, &ConnTrafficTestSuite{})
@@ -305,3 +446,13 @@ func TestConnProxyProtocol(t *testing.T) {
305 446
 	t.Parallel()
306 447
 	suite.Run(t, &ConnProxyProtocolTestSuite{})
307 448
 }
449
+
450
+func TestIdleTracker(t *testing.T) {
451
+	t.Parallel()
452
+	suite.Run(t, &IdleTrackerTestSuite{})
453
+}
454
+
455
+func TestConnIdleTimeout(t *testing.T) {
456
+	t.Parallel()
457
+	suite.Run(t, &ConnIdleTimeoutTestSuite{})
458
+}

+ 6
- 2
mtglib/internal/dc/view.go Просмотреть файл

@@ -5,15 +5,19 @@ type dcView struct {
5 5
 }
6 6
 
7 7
 func (d dcView) getV4(dc int) []Addr {
8
-	addrs := d.publicConfigs.getV4(dc)
8
+	var addrs []Addr
9
+
9 10
 	addrs = append(addrs, defaultDCAddrSet.getV4(dc)...)
11
+	addrs = append(addrs, d.publicConfigs.getV4(dc)...)
10 12
 
11 13
 	return addrs
12 14
 }
13 15
 
14 16
 func (d dcView) getV6(dc int) []Addr {
15
-	addrs := d.publicConfigs.getV6(dc)
17
+	var addrs []Addr
18
+
16 19
 	addrs = append(addrs, defaultDCAddrSet.getV6(dc)...)
20
+	addrs = append(addrs, d.publicConfigs.getV6(dc)...)
17 21
 
18 22
 	return addrs
19 23
 }

+ 12
- 6
mtglib/internal/doppel/scout.go Просмотреть файл

@@ -61,23 +61,29 @@ func (s Scout) learn(ctx context.Context, url string) (ScoutResult, error) {
61 61
 		client.CloseIdleConnections()
62 62
 	}
63 63
 
64
-	if err != nil || len(results.data) == 0 {
64
+	if err != nil {
65 65
 		return ScoutResult{}, err
66 66
 	}
67 67
 
68
+	data, writeIndex := results.Snapshot()
69
+
70
+	if len(data) == 0 {
71
+		return ScoutResult{}, nil
72
+	}
73
+
68 74
 	var result ScoutResult
69 75
 
70 76
 	// Compute inter-record durations (existing logic).
71 77
 	lastTimestamp := time.Time{}
72 78
 
73
-	for i, v := range results.data {
79
+	for i, v := range data {
74 80
 		if v.recordType != tls.TypeApplicationData {
75 81
 			continue
76 82
 		}
77 83
 
78 84
 		if lastTimestamp.IsZero() {
79 85
 			if i > 0 {
80
-				lastTimestamp = results.data[i-1].timestamp
86
+				lastTimestamp = data[i-1].timestamp
81 87
 			} else {
82 88
 				lastTimestamp = v.timestamp
83 89
 			}
@@ -90,12 +96,12 @@ func (s Scout) learn(ctx context.Context, url string) (ScoutResult, error) {
90 96
 	// Compute cert size: sum of ApplicationData payload between CCS and
91 97
 	// the first client Write (which marks the end of server handshake).
92 98
 	seenCCS := false
93
-	boundary := results.writeIndex
99
+	boundary := writeIndex
94 100
 	if boundary < 0 {
95
-		boundary = len(results.data)
101
+		boundary = len(data)
96 102
 	}
97 103
 
98
-	for i, v := range results.data {
104
+	for i, v := range data {
99 105
 		if i >= boundary {
100 106
 			break
101 107
 		}

+ 20
- 1
mtglib/internal/doppel/scout_conn_collected.go Просмотреть файл

@@ -1,6 +1,10 @@
1 1
 package doppel
2 2
 
3
-import "time"
3
+import (
4
+	"slices"
5
+	"sync"
6
+	"time"
7
+)
4 8
 
5 9
 const (
6 10
 	ScoutConnCollectedPreallocSize = 100
@@ -13,23 +17,38 @@ type ScoutConnResult struct {
13 17
 }
14 18
 
15 19
 type ScoutConnCollected struct {
20
+	mu         sync.Mutex
16 21
 	data       []ScoutConnResult
17 22
 	writeIndex int // index at which client first wrote post-handshake data; -1 if not set
18 23
 }
19 24
 
20 25
 func (s *ScoutConnCollected) Add(record byte, payloadLen int) {
26
+	s.mu.Lock()
21 27
 	s.data = append(s.data, ScoutConnResult{
22 28
 		timestamp:  time.Now(),
23 29
 		recordType: record,
24 30
 		payloadLen: payloadLen,
25 31
 	})
32
+	s.mu.Unlock()
26 33
 }
27 34
 
28 35
 // MarkWrite records the current data length as the handshake boundary.
29 36
 func (s *ScoutConnCollected) MarkWrite() {
37
+	s.mu.Lock()
30 38
 	if s.writeIndex < 0 {
31 39
 		s.writeIndex = len(s.data)
32 40
 	}
41
+	s.mu.Unlock()
42
+}
43
+
44
+// Snapshot returns a copy of the collected data and the write index.
45
+func (s *ScoutConnCollected) Snapshot() ([]ScoutConnResult, int) {
46
+	s.mu.Lock()
47
+	snapshot := slices.Clone(s.data)
48
+	writeIndex := s.writeIndex
49
+	s.mu.Unlock()
50
+
51
+	return snapshot, writeIndex
33 52
 }
34 53
 
35 54
 func NewScoutConnCollected() *ScoutConnCollected {

+ 48
- 4
mtglib/internal/doppel/scout_conn_collected_test.go Просмотреть файл

@@ -1,6 +1,7 @@
1 1
 package doppel
2 2
 
3 3
 import (
4
+	"sync"
4 5
 	"testing"
5 6
 	"time"
6 7
 
@@ -16,8 +17,10 @@ func (suite *ScoutConnCollectedTestSuite) TestAddSingle() {
16 17
 	collected := NewScoutConnCollected()
17 18
 	collected.Add(tls.TypeApplicationData, 100)
18 19
 
19
-	suite.Len(collected.data, 1)
20
-	suite.Equal(byte(tls.TypeApplicationData), collected.data[0].recordType)
20
+	data, _ := collected.Snapshot()
21
+
22
+	suite.Len(data, 1)
23
+	suite.Equal(byte(tls.TypeApplicationData), data[0].recordType)
21 24
 }
22 25
 
23 26
 func (suite *ScoutConnCollectedTestSuite) TestAddTimestampsAreMonotonic() {
@@ -31,11 +34,52 @@ func (suite *ScoutConnCollectedTestSuite) TestAddTimestampsAreMonotonic() {
31 34
 	time.Sleep(time.Microsecond)
32 35
 	collected.Add(tls.TypeApplicationData, 100)
33 36
 
34
-	for i := 1; i < len(collected.data); i++ {
35
-		suite.True(collected.data[i].timestamp.After(collected.data[i-1].timestamp))
37
+	data, _ := collected.Snapshot()
38
+
39
+	for i := 1; i < len(data); i++ {
40
+		suite.True(data[i].timestamp.After(data[i-1].timestamp))
36 41
 	}
37 42
 }
38 43
 
44
+func (suite *ScoutConnCollectedTestSuite) TestConcurrentAddSnapshot() {
45
+	collected := NewScoutConnCollected()
46
+
47
+	var wg sync.WaitGroup
48
+
49
+	wg.Add(3)
50
+
51
+	go func() {
52
+		defer wg.Done()
53
+
54
+		for i := 0; i < 1000; i++ {
55
+			collected.Add(tls.TypeApplicationData, i)
56
+		}
57
+	}()
58
+
59
+	go func() {
60
+		defer wg.Done()
61
+
62
+		for i := 0; i < 100; i++ {
63
+			collected.MarkWrite()
64
+		}
65
+	}()
66
+
67
+	go func() {
68
+		defer wg.Done()
69
+
70
+		for i := 0; i < 1000; i++ {
71
+			// call Snapshot concurrently to exercise the lock under -race
72
+			collected.Snapshot() //nolint:errcheck
73
+		}
74
+	}()
75
+
76
+	wg.Wait()
77
+
78
+	data, writeIndex := collected.Snapshot()
79
+	suite.Len(data, 1000)
80
+	suite.GreaterOrEqual(writeIndex, 0)
81
+}
82
+
39 83
 func TestScoutConnCollected(t *testing.T) {
40 84
 	t.Parallel()
41 85
 	suite.Run(t, &ScoutConnCollectedTestSuite{})

+ 122
- 16
mtglib/internal/tls/fake/client_side.go Просмотреть файл

@@ -6,6 +6,7 @@ import (
6 6
 	"crypto/sha256"
7 7
 	"crypto/subtle"
8 8
 	"encoding/binary"
9
+	"errors"
9 10
 	"fmt"
10 11
 	"io"
11 12
 	"net"
@@ -23,6 +24,11 @@ const (
23 24
 	RandomOffset = 1 + 2 + 2 + 1 + 3 + 2
24 25
 
25 26
 	sniDNSNamesListType = 0
27
+
28
+	// maxContinuationRecords limits the number of continuation TLS records
29
+	// that reassembleTLSHandshake will read. This prevents resource exhaustion
30
+	// from adversarial fragmentation.
31
+	maxContinuationRecords = 10
26 32
 )
27 33
 
28 34
 var (
@@ -71,12 +77,18 @@ func ReadClientHelloMulti(
71 77
 	}
72 78
 	defer conn.SetReadDeadline(resetDeadline) //nolint: errcheck
73 79
 
80
+	reassembled, err := reassembleTLSHandshake(conn)
81
+	if err != nil {
82
+		return nil, fmt.Errorf("cannot reassemble TLS records: %w", err)
83
+	}
84
+
74 85
 	handshakeCopyBuf := &bytes.Buffer{}
75
-	reader := io.TeeReader(conn, handshakeCopyBuf)
86
+	reader := io.TeeReader(reassembled, handshakeCopyBuf)
76 87
 
77
-	reader, err := parseTLSHeader(reader)
78
-	if err != nil {
79
-		return nil, fmt.Errorf("cannot parse tls header: %w", err)
88
+	// Skip the TLS record header (validated during reassembly).
89
+	// The header still flows through TeeReader into handshakeCopyBuf for HMAC.
90
+	if _, err = io.CopyN(io.Discard, reader, tls.SizeHeader); err != nil {
91
+		return nil, fmt.Errorf("cannot skip tls header: %w", err)
80 92
 	}
81 93
 
82 94
 	reader, err = parseHandshakeHeader(reader)
@@ -135,17 +147,30 @@ func ReadClientHelloMulti(
135 147
 	return nil, ErrBadDigest
136 148
 }
137 149
 
138
-func parseTLSHeader(r io.Reader) (io.Reader, error) {
139
-	// record_type(1) + version(2) + size(2)
140
-	//   16 - type is 0x16 (handshake record)
141
-	//   03 01 - protocol version is "3,1" (also known as TLS 1.0)
142
-	//   00 f8 - 0xF8 (248) bytes of handshake message follows
143
-	header := [1 + 2 + 2]byte{}
144
-
145
-	if _, err := io.ReadFull(r, header[:]); err != nil {
150
+// reassembleTLSHandshake reads one or more TLS records from conn,
151
+// validates the record type and version, and reassembles fragmented
152
+// handshake payloads into a single TLS record.
153
+//
154
+// Per RFC 5246 Section 6.2.1, handshake messages may be fragmented
155
+// across multiple TLS records. DPI bypass tools like ByeDPI use this
156
+// to evade censorship.
157
+//
158
+// The returned buffer contains the full TLS record (header + payload)
159
+// so that callers can include the header in HMAC computation.
160
+func reassembleTLSHandshake(conn io.Reader) (*bytes.Buffer, error) {
161
+	header := [tls.SizeHeader]byte{}
162
+
163
+	if _, err := io.ReadFull(conn, header[:]); err != nil {
146 164
 		return nil, fmt.Errorf("cannot read record header: %w", err)
147 165
 	}
148 166
 
167
+	length := int64(binary.BigEndian.Uint16(header[3:]))
168
+	payload := &bytes.Buffer{}
169
+
170
+	if _, err := io.CopyN(payload, conn, length); err != nil {
171
+		return nil, fmt.Errorf("cannot read record payload: %w", err)
172
+	}
173
+
149 174
 	if header[0] != tls.TypeHandshake {
150 175
 		return nil, fmt.Errorf("unexpected record type %#x", header[0])
151 176
 	}
@@ -154,12 +179,93 @@ func parseTLSHeader(r io.Reader) (io.Reader, error) {
154 179
 		return nil, fmt.Errorf("unexpected protocol version %#x %#x", header[1], header[2])
155 180
 	}
156 181
 
157
-	length := int64(binary.BigEndian.Uint16(header[3:]))
158
-	buf := &bytes.Buffer{}
182
+	// Reassemble fragmented payload. continuationCount caps the total
183
+	// number of continuation records across both phases below.
184
+	continuationCount := 0
159 185
 
160
-	_, err := io.CopyN(buf, r, length)
186
+	// Phase 1: read continuation records until we have at least the
187
+	// 4-byte handshake header (type + uint24 length) to determine the
188
+	// expected total size.
189
+	for ; payload.Len() < 4 && continuationCount < maxContinuationRecords; continuationCount++ {
190
+		prevLen := payload.Len()
161 191
 
162
-	return buf, err
192
+		if err := readContinuationRecord(conn, payload); err != nil {
193
+			payload.Truncate(prevLen) // discard partial data on error
194
+
195
+			if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) {
196
+				break // no more records — let downstream parsing handle what we have
197
+			}
198
+
199
+			return nil, err
200
+		}
201
+	}
202
+
203
+	// Phase 2: we know the expected handshake size — read remaining
204
+	// continuation records until the payload is complete.
205
+	if payload.Len() >= 4 {
206
+		p := payload.Bytes()
207
+		expectedTotal := 4 + (int(p[1])<<16 | int(p[2])<<8 | int(p[3]))
208
+
209
+		if expectedTotal > 0xFFFF {
210
+			return nil, fmt.Errorf("handshake message too large: %d bytes", expectedTotal)
211
+		}
212
+
213
+		for ; payload.Len() < expectedTotal && continuationCount < maxContinuationRecords; continuationCount++ {
214
+			if err := readContinuationRecord(conn, payload); err != nil {
215
+				return nil, err
216
+			}
217
+		}
218
+
219
+		if payload.Len() < expectedTotal {
220
+			return nil, fmt.Errorf("cannot reassemble handshake: too many continuation records")
221
+		}
222
+
223
+		payload.Truncate(expectedTotal)
224
+	}
225
+
226
+	if payload.Len() > 0xFFFF {
227
+		return nil, fmt.Errorf("reassembled payload too large: %d bytes", payload.Len())
228
+	}
229
+
230
+	// Reconstruct a single TLS record with the reassembled payload.
231
+	result := &bytes.Buffer{}
232
+	result.Grow(tls.SizeHeader + payload.Len())
233
+	result.Write(header[:3])
234
+	binary.Write(result, binary.BigEndian, uint16(payload.Len())) //nolint:errcheck // bytes.Buffer.Write never fails
235
+	result.Write(payload.Bytes())
236
+
237
+	return result, nil
238
+}
239
+
240
+// readContinuationRecord reads the next TLS record header and appends its
241
+// full payload to dst. It returns an error if the record is not a handshake
242
+// record.
243
+func readContinuationRecord(conn io.Reader, dst *bytes.Buffer) error {
244
+	nextHeader := [tls.SizeHeader]byte{}
245
+
246
+	if _, err := io.ReadFull(conn, nextHeader[:]); err != nil {
247
+		return fmt.Errorf("cannot read continuation record header: %w", err)
248
+	}
249
+
250
+	if nextHeader[0] != tls.TypeHandshake {
251
+		return fmt.Errorf("unexpected continuation record type %#x", nextHeader[0])
252
+	}
253
+
254
+	if nextHeader[1] != 3 || nextHeader[2] != 1 {
255
+		return fmt.Errorf("unexpected continuation record version %#x %#x", nextHeader[1], nextHeader[2])
256
+	}
257
+
258
+	nextLength := int64(binary.BigEndian.Uint16(nextHeader[3:]))
259
+
260
+	if nextLength == 0 {
261
+		return fmt.Errorf("zero-length continuation record")
262
+	}
263
+
264
+	if _, err := io.CopyN(dst, conn, nextLength); err != nil {
265
+		return fmt.Errorf("cannot read continuation record payload: %w", err)
266
+	}
267
+
268
+	return nil
163 269
 }
164 270
 
165 271
 func parseHandshakeHeader(r io.Reader) (io.Reader, error) {

+ 272
- 0
mtglib/internal/tls/fake/client_side_test.go Просмотреть файл

@@ -530,3 +530,275 @@ func TestReadClientHelloMulti(t *testing.T) {
530 530
 	t.Parallel()
531 531
 	suite.Run(t, &ReadClientHelloMultiTestSuite{})
532 532
 }
533
+
534
+// --- Fragmented TLS record tests ---
535
+
536
+// fragmentTLSRecord splits a single TLS record into n TLS records by
537
+// dividing the payload into roughly equal parts. Each part gets its own
538
+// TLS record header with the same record type and version.
539
+func fragmentTLSRecord(t testing.TB, full []byte, n int) []byte {
540
+	t.Helper()
541
+
542
+	recordType := full[0]
543
+	version := full[1:3]
544
+	payload := full[tls.SizeHeader:]
545
+
546
+	chunkSize := len(payload) / n
547
+	result := &bytes.Buffer{}
548
+
549
+	for i := 0; i < n; i++ {
550
+		start := i * chunkSize
551
+		end := start + chunkSize
552
+
553
+		if i == n-1 {
554
+			end = len(payload)
555
+		}
556
+
557
+		chunk := payload[start:end]
558
+		result.WriteByte(recordType)
559
+		result.Write(version)
560
+		require.NoError(t, binary.Write(result, binary.BigEndian, uint16(len(chunk))))
561
+		result.Write(chunk)
562
+	}
563
+
564
+	return result.Bytes()
565
+}
566
+
567
+// splitPayloadAt creates two TLS records from a single record by splitting
568
+// the payload at the given byte position.
569
+func splitPayloadAt(t testing.TB, full []byte, pos int) []byte {
570
+	t.Helper()
571
+
572
+	payload := full[tls.SizeHeader:]
573
+	buf := &bytes.Buffer{}
574
+
575
+	buf.WriteByte(tls.TypeHandshake)
576
+	buf.Write(full[1:3])
577
+	require.NoError(t, binary.Write(buf, binary.BigEndian, uint16(pos)))
578
+	buf.Write(payload[:pos])
579
+
580
+	buf.WriteByte(tls.TypeHandshake)
581
+	buf.Write(full[1:3])
582
+	require.NoError(t, binary.Write(buf, binary.BigEndian, uint16(len(payload)-pos)))
583
+	buf.Write(payload[pos:])
584
+
585
+	return buf.Bytes()
586
+}
587
+
588
+type ParseClientHelloFragmentedTestSuite struct {
589
+	suite.Suite
590
+
591
+	secret   mtglib.Secret
592
+	snapshot *clientHelloSnapshot
593
+}
594
+
595
+func (s *ParseClientHelloFragmentedTestSuite) SetupSuite() {
596
+	parsed, err := mtglib.ParseSecret(
597
+		"ee367a189aee18fa31c190054efd4a8e9573746f726167652e676f6f676c65617069732e636f6d",
598
+	)
599
+	require.NoError(s.T(), err)
600
+
601
+	s.secret = parsed
602
+
603
+	fileData, err := os.ReadFile("testdata/client-hello-ok-19dfe38384b9884b.json")
604
+	require.NoError(s.T(), err)
605
+
606
+	s.snapshot = &clientHelloSnapshot{}
607
+	require.NoError(s.T(), json.Unmarshal(fileData, s.snapshot))
608
+}
609
+
610
+func (s *ParseClientHelloFragmentedTestSuite) makeConn(data []byte) *parseClientHelloConnMock {
611
+	readBuf := &bytes.Buffer{}
612
+	readBuf.Write(data)
613
+
614
+	connMock := &parseClientHelloConnMock{
615
+		readBuf: readBuf,
616
+	}
617
+
618
+	connMock.
619
+		On("SetReadDeadline", mock.AnythingOfType("time.Time")).
620
+		Twice().
621
+		Return(nil)
622
+
623
+	return connMock
624
+}
625
+
626
+func (s *ParseClientHelloFragmentedTestSuite) TestReassemblySuccess() {
627
+	full := s.snapshot.GetFull()
628
+
629
+	tests := []struct {
630
+		name string
631
+		data []byte
632
+	}{
633
+		{"two equal fragments", fragmentTLSRecord(s.T(), full, 2)},
634
+		{"three equal fragments", fragmentTLSRecord(s.T(), full, 3)},
635
+		{"single byte first fragment", splitPayloadAt(s.T(), full, 1)},
636
+		{"three byte first fragment", splitPayloadAt(s.T(), full, 3)},
637
+	}
638
+
639
+	for _, tt := range tests {
640
+		s.Run(tt.name, func() {
641
+			connMock := s.makeConn(tt.data)
642
+			defer connMock.AssertExpectations(s.T())
643
+
644
+			hello, err := fake.ReadClientHello(
645
+				connMock,
646
+				s.secret.Key[:],
647
+				s.secret.Host,
648
+				TolerateTime,
649
+			)
650
+			s.Require().NoError(err)
651
+
652
+			s.Equal(s.snapshot.GetRandom(), hello.Random[:])
653
+			s.Equal(s.snapshot.GetSessionID(), hello.SessionID)
654
+			s.Equal(uint16(s.snapshot.CipherSuite), hello.CipherSuite)
655
+		})
656
+	}
657
+}
658
+
659
+func (s *ParseClientHelloFragmentedTestSuite) TestReassemblyErrors() {
660
+	full := s.snapshot.GetFull()
661
+	payload := full[tls.SizeHeader:]
662
+
663
+	tests := []struct {
664
+		name      string
665
+		buildData func() []byte
666
+		errMsg    string
667
+	}{
668
+		{
669
+			name: "wrong continuation record type",
670
+			buildData: func() []byte {
671
+				buf := &bytes.Buffer{}
672
+				buf.WriteByte(tls.TypeHandshake)
673
+				buf.Write(full[1:3])
674
+				require.NoError(s.T(), binary.Write(buf, binary.BigEndian, uint16(10)))
675
+				buf.Write(payload[:10])
676
+				// Wrong type: application data instead of handshake
677
+				buf.WriteByte(tls.TypeApplicationData)
678
+				buf.Write(full[1:3])
679
+				require.NoError(s.T(), binary.Write(buf, binary.BigEndian, uint16(len(payload)-10)))
680
+				buf.Write(payload[10:])
681
+				return buf.Bytes()
682
+			},
683
+			errMsg: "unexpected continuation record type",
684
+		},
685
+		{
686
+			name: "too many continuation records",
687
+			buildData: func() []byte {
688
+				// Handshake header claiming 256 bytes, but we only send 1 byte per continuation
689
+				handshakePayload := []byte{0x01, 0x00, 0x01, 0x00}
690
+				buf := &bytes.Buffer{}
691
+				buf.WriteByte(tls.TypeHandshake)
692
+				buf.Write([]byte{3, 1})
693
+				require.NoError(s.T(), binary.Write(buf, binary.BigEndian, uint16(len(handshakePayload))))
694
+				buf.Write(handshakePayload)
695
+				for range 11 {
696
+					buf.WriteByte(tls.TypeHandshake)
697
+					buf.Write([]byte{3, 1})
698
+					require.NoError(s.T(), binary.Write(buf, binary.BigEndian, uint16(1)))
699
+					buf.WriteByte(0xAB)
700
+				}
701
+				return buf.Bytes()
702
+			},
703
+			errMsg: "too many continuation records",
704
+		},
705
+		{
706
+			name: "zero-length continuation record",
707
+			buildData: func() []byte {
708
+				buf := &bytes.Buffer{}
709
+				buf.WriteByte(tls.TypeHandshake)
710
+				buf.Write(full[1:3])
711
+				require.NoError(s.T(), binary.Write(buf, binary.BigEndian, uint16(10)))
712
+				buf.Write(payload[:10])
713
+				// Valid header but zero-length payload
714
+				buf.WriteByte(tls.TypeHandshake)
715
+				buf.Write(full[1:3])
716
+				require.NoError(s.T(), binary.Write(buf, binary.BigEndian, uint16(0)))
717
+				return buf.Bytes()
718
+			},
719
+			errMsg: "zero-length continuation record",
720
+		},
721
+		{
722
+			name: "wrong continuation record version",
723
+			buildData: func() []byte {
724
+				buf := &bytes.Buffer{}
725
+				buf.WriteByte(tls.TypeHandshake)
726
+				buf.Write(full[1:3])
727
+				require.NoError(s.T(), binary.Write(buf, binary.BigEndian, uint16(10)))
728
+				buf.Write(payload[:10])
729
+				// Wrong version: 3.3 instead of 3.1
730
+				buf.WriteByte(tls.TypeHandshake)
731
+				buf.Write([]byte{3, 3})
732
+				require.NoError(s.T(), binary.Write(buf, binary.BigEndian, uint16(len(payload)-10)))
733
+				buf.Write(payload[10:])
734
+				return buf.Bytes()
735
+			},
736
+			errMsg: "unexpected continuation record version",
737
+		},
738
+		{
739
+			name: "handshake message too large",
740
+			buildData: func() []byte {
741
+				// Handshake header claiming 0x010000 (65536) bytes — exceeds 0xFFFF limit
742
+				handshakePayload := []byte{0x01, 0x01, 0x00, 0x00}
743
+				buf := &bytes.Buffer{}
744
+				buf.WriteByte(tls.TypeHandshake)
745
+				buf.Write([]byte{3, 1})
746
+				require.NoError(s.T(), binary.Write(buf, binary.BigEndian, uint16(len(handshakePayload))))
747
+				buf.Write(handshakePayload)
748
+				return buf.Bytes()
749
+			},
750
+			errMsg: "handshake message too large",
751
+		},
752
+		{
753
+			name: "truncated continuation record header",
754
+			buildData: func() []byte {
755
+				buf := &bytes.Buffer{}
756
+				buf.WriteByte(tls.TypeHandshake)
757
+				buf.Write(full[1:3])
758
+				require.NoError(s.T(), binary.Write(buf, binary.BigEndian, uint16(10)))
759
+				buf.Write(payload[:10])
760
+				// Connection ends mid-header (only 2 bytes)
761
+				buf.WriteByte(tls.TypeHandshake)
762
+				buf.WriteByte(3)
763
+				return buf.Bytes()
764
+			},
765
+			errMsg: "cannot read continuation record header",
766
+		},
767
+		{
768
+			name: "truncated continuation record payload",
769
+			buildData: func() []byte {
770
+				buf := &bytes.Buffer{}
771
+				buf.WriteByte(tls.TypeHandshake)
772
+				buf.Write(full[1:3])
773
+				require.NoError(s.T(), binary.Write(buf, binary.BigEndian, uint16(10)))
774
+				buf.Write(payload[:10])
775
+				// Claims 100 bytes but no payload follows
776
+				buf.WriteByte(tls.TypeHandshake)
777
+				buf.Write(full[1:3])
778
+				require.NoError(s.T(), binary.Write(buf, binary.BigEndian, uint16(100)))
779
+				return buf.Bytes()
780
+			},
781
+			errMsg: "cannot read continuation record payload",
782
+		},
783
+	}
784
+
785
+	for _, tt := range tests {
786
+		s.Run(tt.name, func() {
787
+			connMock := s.makeConn(tt.buildData())
788
+			defer connMock.AssertExpectations(s.T())
789
+
790
+			_, err := fake.ReadClientHello(
791
+				connMock,
792
+				s.secret.Key[:],
793
+				s.secret.Host,
794
+				TolerateTime,
795
+			)
796
+			s.ErrorContains(err, tt.errMsg)
797
+		})
798
+	}
799
+}
800
+
801
+func TestParseClientHelloFragmented(t *testing.T) {
802
+	t.Parallel()
803
+	suite.Run(t, &ParseClientHelloFragmentedTestSuite{})
804
+}

+ 8
- 4
mtglib/proxy.go Просмотреть файл

@@ -111,11 +111,13 @@ func (p *Proxy) ServeConn(conn essentials.Conn) {
111 111
 		return
112 112
 	}
113 113
 
114
+	tracker := newIdleTracker(p.idleTimeout)
115
+
114 116
 	relay.Relay(
115 117
 		ctx,
116 118
 		ctx.logger.Named("relay"),
117
-		connIdleTimeout{Conn: ctx.telegramConn, timeout: p.idleTimeout},
118
-		newCountingConn(connIdleTimeout{Conn: ctx.clientConn, timeout: p.idleTimeout}, p.stats, ctx.secretName),
119
+		connIdleTimeout{Conn: ctx.telegramConn, tracker: tracker},
120
+		newCountingConn(connIdleTimeout{Conn: ctx.clientConn, tracker: tracker}, p.stats, ctx.secretName),
119 121
 	)
120 122
 }
121 123
 
@@ -330,11 +332,13 @@ func (p *Proxy) doDomainFronting(ctx *streamContext, conn *connRewind) {
330 332
 		stream:   p.eventStream,
331 333
 	}
332 334
 
335
+	tracker := newIdleTracker(p.idleTimeout)
336
+
333 337
 	relay.Relay(
334 338
 		ctx,
335 339
 		ctx.logger.Named("domain-fronting"),
336
-		connIdleTimeout{Conn: frontConn, timeout: p.idleTimeout},
337
-		connIdleTimeout{Conn: conn, timeout: p.idleTimeout},
340
+		connIdleTimeout{Conn: frontConn, tracker: tracker},
341
+		connIdleTimeout{Conn: conn, tracker: tracker},
338 342
 	)
339 343
 }
340 344
 

+ 1
- 1
mtglib/proxy_test.go Просмотреть файл

@@ -175,7 +175,7 @@ func (suite *ProxyTestSuite) TestHTTPSRequest() {
175 175
 	addr := fmt.Sprintf("https://%s/headers", suite.ProxyAddress())
176 176
 
177 177
 	resp, err := client.Get(addr) //nolint: noctx
178
-	suite.NoError(err)
178
+	suite.Require().NoError(err)
179 179
 
180 180
 	defer resp.Body.Close() //nolint: errcheck
181 181
 

Загрузка…
Отмена
Сохранить