Kaynağa Gözat

Merge remote-tracking branch 'origin/stable' into v2

tags/v2.2.7
9seconds 1 ay önce
ebeveyn
işleme
1a81efcb6e

+ 1
- 1
.github/workflows/ci.yaml Dosyayı Görüntüle

@@ -122,7 +122,7 @@ jobs:
122 122
   artifacts:
123 123
     name: Build release artifacts
124 124
     runs-on: ubuntu-latest
125
-    timeout-minutes: 10
125
+    timeout-minutes: 20
126 126
     steps:
127 127
       - name: Checkout
128 128
         uses: actions/checkout@v6

+ 1
- 1
Dockerfile Dosyayı Görüntüle

@@ -33,7 +33,7 @@ RUN go mod download
33 33
 COPY . /app
34 34
 
35 35
 RUN set -x \
36
-  && version="$(git describe --exact-match HEAD 2>/dev/null || git describe --tags --always)" \
36
+  && version="$(git describe --exact-match HEAD 2>/dev/null || git describe --tags --always 2>/dev/null || echo dev)" \
37 37
   && go build \
38 38
       -trimpath \
39 39
       -mod=readonly \

+ 2
- 0
README.md Dosyayı Görüntüle

@@ -29,6 +29,8 @@ are the most notable:
29 29
 * [Official](https://github.com/TelegramMessenger/MTProxy)
30 30
 * [Python](https://github.com/alexbers/mtprotoproxy)
31 31
 * [Erlang](https://github.com/seriyps/mtproto_proxy)
32
+* [Teleproxy (C)](https://github.com/teleproxy/teleproxy)
33
+* [mtproto.zig (Zig)](https://github.com/sleep3r/mtproto.zig)
32 34
 * [Telemt (Rust)](https://github.com/telemt/telemt)
33 35
 
34 36
 You can use any of these. They work great and all implementations have

+ 1
- 1
antireplay/stable_bloom_filter_test.go Dosyayı Görüntüle

@@ -12,7 +12,7 @@ type StableBloomFilterTestSuite struct {
12 12
 }
13 13
 
14 14
 func (suite *StableBloomFilterTestSuite) TestOp() {
15
-	filter := antireplay.NewStableBloomFilter(500, 0.001)
15
+	filter := antireplay.NewStableBloomFilter(100000, 0.001)
16 16
 
17 17
 	suite.False(filter.SeenBefore([]byte{1, 2, 3}))
18 18
 	suite.False(filter.SeenBefore([]byte{4, 5, 6}))

+ 4
- 4
internal/config/config.go Dosyayı Görüntüle

@@ -52,10 +52,10 @@ type Config struct {
52 52
 		Blocklist    ListConfig `json:"blocklist"`
53 53
 		Allowlist    ListConfig `json:"allowlist"`
54 54
 		Doppelganger struct {
55
-			URLs            []TypeHttpsURL  `json:"urls"`
56
-			Repeats         TypeConcurrency `json:"repeats_per_raid"`
57
-			UpdateEach      TypeDuration    `json:"raid_each"`
58
-			DRS             TypeBool        `json:"drs"`
55
+			URLs       []TypeHttpsURL  `json:"urls"`
56
+			Repeats    TypeConcurrency `json:"repeats_per_raid"`
57
+			UpdateEach TypeDuration    `json:"raid_each"`
58
+			DRS        TypeBool        `json:"drs"`
59 59
 		} `json:"doppelganger"`
60 60
 	} `json:"defense"`
61 61
 	Network struct {

+ 4
- 4
internal/config/parse.go Dosyayı Görüntüle

@@ -47,10 +47,10 @@ type tomlConfig struct {
47 47
 			UpdateEach          string   `toml:"update-each" json:"updateEach,omitempty"`
48 48
 		} `toml:"allowlist" json:"allowlist,omitempty"`
49 49
 		Doppelganger struct {
50
-			URLs            []string `toml:"urls" json:"urls,omitempty"`
51
-			Repeats         uint     `toml:"repeats-per-raid" json:"repeats_per_raid,omitempty"`
52
-			UpdateEach      string   `toml:"raid-each" json:"raid_each,omitempty"`
53
-			DRS             bool     `toml:"drs" json:"drs,omitempty"`
50
+			URLs       []string `toml:"urls" json:"urls,omitempty"`
51
+			Repeats    uint     `toml:"repeats-per-raid" json:"repeats_per_raid,omitempty"`
52
+			UpdateEach string   `toml:"raid-each" json:"raid_each,omitempty"`
53
+			DRS        bool     `toml:"drs" json:"drs,omitempty"`
54 54
 		} `toml:"doppelganger" json:"doppelganger,omitempty"`
55 55
 	} `toml:"defense" json:"defense,omitempty"`
56 56
 	Network struct {

+ 16
- 16
mise.lock Dosyayı Görüntüle

@@ -82,40 +82,40 @@ checksum = "sha256:4932cfca5e75bf60fe1c576edf459e5e809e6644664a068185d64b84af3fa
82 82
 url = "https://github.com/golangci/golangci-lint/releases/download/v2.11.4/golangci-lint-2.11.4-windows-amd64.zip"
83 83
 
84 84
 [[tools.goreleaser]]
85
-version = "2.14.3"
85
+version = "2.15.2"
86 86
 backend = "aqua:goreleaser/goreleaser"
87 87
 
88 88
 [tools.goreleaser."platforms.linux-arm64"]
89
-checksum = "sha256:581a10e53c1176b3e81ee45cf531e02dbf899db0bc7b795669347df4276ce948"
90
-url = "https://github.com/goreleaser/goreleaser/releases/download/v2.14.3/goreleaser_Linux_arm64.tar.gz"
89
+checksum = "sha256:5db66761a98f6693161e49e1a95d28d2673a892ba60cb4a5e16736cafd41c4c9"
90
+url = "https://github.com/goreleaser/goreleaser/releases/download/v2.15.2/goreleaser_Linux_arm64.tar.gz"
91 91
 provenance = "cosign"
92 92
 
93 93
 [tools.goreleaser."platforms.linux-arm64-musl"]
94
-checksum = "sha256:581a10e53c1176b3e81ee45cf531e02dbf899db0bc7b795669347df4276ce948"
95
-url = "https://github.com/goreleaser/goreleaser/releases/download/v2.14.3/goreleaser_Linux_arm64.tar.gz"
94
+checksum = "sha256:5db66761a98f6693161e49e1a95d28d2673a892ba60cb4a5e16736cafd41c4c9"
95
+url = "https://github.com/goreleaser/goreleaser/releases/download/v2.15.2/goreleaser_Linux_arm64.tar.gz"
96 96
 provenance = "cosign"
97 97
 
98 98
 [tools.goreleaser."platforms.linux-x64"]
99
-checksum = "sha256:dc7faeeeb6da8bdfda788626263a4ae725892a8c7504b975c3234127d4a44579"
100
-url = "https://github.com/goreleaser/goreleaser/releases/download/v2.14.3/goreleaser_Linux_x86_64.tar.gz"
99
+checksum = "sha256:0ebdbf0353aba566b969dde746cc4e4806f96c27aa2f3971b229a9df7611fedc"
100
+url = "https://github.com/goreleaser/goreleaser/releases/download/v2.15.2/goreleaser_Linux_x86_64.tar.gz"
101 101
 provenance = "cosign"
102 102
 
103 103
 [tools.goreleaser."platforms.linux-x64-musl"]
104
-checksum = "sha256:dc7faeeeb6da8bdfda788626263a4ae725892a8c7504b975c3234127d4a44579"
105
-url = "https://github.com/goreleaser/goreleaser/releases/download/v2.14.3/goreleaser_Linux_x86_64.tar.gz"
104
+checksum = "sha256:0ebdbf0353aba566b969dde746cc4e4806f96c27aa2f3971b229a9df7611fedc"
105
+url = "https://github.com/goreleaser/goreleaser/releases/download/v2.15.2/goreleaser_Linux_x86_64.tar.gz"
106 106
 provenance = "cosign"
107 107
 
108 108
 [tools.goreleaser."platforms.macos-arm64"]
109
-checksum = "sha256:3507798489e107a78aff36b169de48148a335ac26eb3161608d905f3f3a957bd"
110
-url = "https://github.com/goreleaser/goreleaser/releases/download/v2.14.3/goreleaser_Darwin_all.tar.gz"
111
-provenance = "cosign"
109
+checksum = "sha256:0e6bd67688ac949780bf1166813a91f89856898ef4c40d7d46c2c74ebaa4b9ee"
110
+url = "https://github.com/goreleaser/goreleaser/releases/download/v2.15.2/goreleaser_Darwin_all.tar.gz"
111
+provenance = "github-attestations"
112 112
 
113 113
 [tools.goreleaser."platforms.macos-x64"]
114
-checksum = "sha256:3507798489e107a78aff36b169de48148a335ac26eb3161608d905f3f3a957bd"
115
-url = "https://github.com/goreleaser/goreleaser/releases/download/v2.14.3/goreleaser_Darwin_all.tar.gz"
114
+checksum = "sha256:0e6bd67688ac949780bf1166813a91f89856898ef4c40d7d46c2c74ebaa4b9ee"
115
+url = "https://github.com/goreleaser/goreleaser/releases/download/v2.15.2/goreleaser_Darwin_all.tar.gz"
116 116
 provenance = "cosign"
117 117
 
118 118
 [tools.goreleaser."platforms.windows-x64"]
119
-checksum = "sha256:3deea8ff471aa258a2d99f3e5302971d7028647ae8ddaf103257a8113e485a31"
120
-url = "https://github.com/goreleaser/goreleaser/releases/download/v2.14.3/goreleaser_Windows_x86_64.zip"
119
+checksum = "sha256:7459832946dbe122c144f8d7f87484d8572ca005b779310aa6bb03346e8de17a"
120
+url = "https://github.com/goreleaser/goreleaser/releases/download/v2.15.2/goreleaser_Windows_x86_64.zip"
121 121
 provenance = "cosign"

+ 14
- 17
mtglib/conns.go Dosyayı Görüntüle

@@ -3,6 +3,7 @@ package mtglib
3 3
 import (
4 4
 	"bytes"
5 5
 	"context"
6
+	"errors"
6 7
 	"fmt"
7 8
 	"io"
8 9
 	"net"
@@ -102,7 +103,7 @@ func newConnProxyProtocol(source, target essentials.Conn) *connProxyProtocol {
102 103
 // Both directions update the same timestamp so that activity in one direction
103 104
 // prevents the other (idle) direction from timing out.
104 105
 type idleTracker struct {
105
-	lastActive atomic.Int64 // unix nanos
106
+	lastActive atomic.Pointer[time.Time]
106 107
 	timeout    time.Duration
107 108
 }
108 109
 
@@ -114,13 +115,12 @@ func newIdleTracker(timeout time.Duration) *idleTracker {
114 115
 }
115 116
 
116 117
 func (t *idleTracker) touch() {
117
-	t.lastActive.Store(time.Now().UnixNano())
118
+	stamp := time.Now()
119
+	t.lastActive.Store(&stamp)
118 120
 }
119 121
 
120 122
 func (t *idleTracker) isIdle() bool {
121
-	last := time.Unix(0, t.lastActive.Load())
122
-
123
-	return time.Since(last) >= t.timeout
123
+	return time.Since(*t.lastActive.Load()) >= t.timeout
124 124
 }
125 125
 
126 126
 type connIdleTimeout struct {
@@ -130,25 +130,22 @@ type connIdleTimeout struct {
130 130
 }
131 131
 
132 132
 func (c connIdleTimeout) Read(b []byte) (int, error) {
133
+	var netErr net.Error
134
+
133 135
 	for {
134 136
 		c.SetReadDeadline(time.Now().Add(c.tracker.timeout)) //nolint: errcheck
135 137
 
136 138
 		n, err := c.Conn.Read(b)
137
-		if n > 0 {
138
-			c.tracker.touch()
139
-
140
-			return n, err //nolint: wrapcheck
141
-		}
142
-
143
-		if err != nil {
144
-			if netErr, ok := err.(net.Error); ok && netErr.Timeout() && !c.tracker.isIdle() { //nolint: errorlint
145
-				continue
146
-			}
147 139
 
148
-			return 0, err //nolint: wrapcheck
140
+		switch {
141
+		case err == nil:
142
+			c.tracker.touch()
143
+			return n, nil
144
+		case errors.As(err, &netErr) && netErr.Timeout() && !c.tracker.isIdle():
145
+			continue
149 146
 		}
150 147
 
151
-		return 0, nil
148
+		return n, err
152 149
 	}
153 150
 }
154 151
 

+ 6
- 2
mtglib/internal/dc/view.go Dosyayı Görüntüle

@@ -5,15 +5,19 @@ type dcView struct {
5 5
 }
6 6
 
7 7
 func (d dcView) getV4(dc int) []Addr {
8
-	addrs := d.publicConfigs.getV4(dc)
8
+	var addrs []Addr
9
+
9 10
 	addrs = append(addrs, defaultDCAddrSet.getV4(dc)...)
11
+	addrs = append(addrs, d.publicConfigs.getV4(dc)...)
10 12
 
11 13
 	return addrs
12 14
 }
13 15
 
14 16
 func (d dcView) getV6(dc int) []Addr {
15
-	addrs := d.publicConfigs.getV6(dc)
17
+	var addrs []Addr
18
+
16 19
 	addrs = append(addrs, defaultDCAddrSet.getV6(dc)...)
20
+	addrs = append(addrs, d.publicConfigs.getV6(dc)...)
17 21
 
18 22
 	return addrs
19 23
 }

+ 12
- 6
mtglib/internal/doppel/scout.go Dosyayı Görüntüle

@@ -61,23 +61,29 @@ func (s Scout) learn(ctx context.Context, url string) (ScoutResult, error) {
61 61
 		client.CloseIdleConnections()
62 62
 	}
63 63
 
64
-	if err != nil || len(results.data) == 0 {
64
+	if err != nil {
65 65
 		return ScoutResult{}, err
66 66
 	}
67 67
 
68
+	data, writeIndex := results.Snapshot()
69
+
70
+	if len(data) == 0 {
71
+		return ScoutResult{}, nil
72
+	}
73
+
68 74
 	var result ScoutResult
69 75
 
70 76
 	// Compute inter-record durations (existing logic).
71 77
 	lastTimestamp := time.Time{}
72 78
 
73
-	for i, v := range results.data {
79
+	for i, v := range data {
74 80
 		if v.recordType != tls.TypeApplicationData {
75 81
 			continue
76 82
 		}
77 83
 
78 84
 		if lastTimestamp.IsZero() {
79 85
 			if i > 0 {
80
-				lastTimestamp = results.data[i-1].timestamp
86
+				lastTimestamp = data[i-1].timestamp
81 87
 			} else {
82 88
 				lastTimestamp = v.timestamp
83 89
 			}
@@ -90,12 +96,12 @@ func (s Scout) learn(ctx context.Context, url string) (ScoutResult, error) {
90 96
 	// Compute cert size: sum of ApplicationData payload between CCS and
91 97
 	// the first client Write (which marks the end of server handshake).
92 98
 	seenCCS := false
93
-	boundary := results.writeIndex
99
+	boundary := writeIndex
94 100
 	if boundary < 0 {
95
-		boundary = len(results.data)
101
+		boundary = len(data)
96 102
 	}
97 103
 
98
-	for i, v := range results.data {
104
+	for i, v := range data {
99 105
 		if i >= boundary {
100 106
 			break
101 107
 		}

+ 20
- 1
mtglib/internal/doppel/scout_conn_collected.go Dosyayı Görüntüle

@@ -1,6 +1,10 @@
1 1
 package doppel
2 2
 
3
-import "time"
3
+import (
4
+	"slices"
5
+	"sync"
6
+	"time"
7
+)
4 8
 
5 9
 const (
6 10
 	ScoutConnCollectedPreallocSize = 100
@@ -13,23 +17,38 @@ type ScoutConnResult struct {
13 17
 }
14 18
 
15 19
 type ScoutConnCollected struct {
20
+	mu         sync.Mutex
16 21
 	data       []ScoutConnResult
17 22
 	writeIndex int // index at which client first wrote post-handshake data; -1 if not set
18 23
 }
19 24
 
20 25
 func (s *ScoutConnCollected) Add(record byte, payloadLen int) {
26
+	s.mu.Lock()
21 27
 	s.data = append(s.data, ScoutConnResult{
22 28
 		timestamp:  time.Now(),
23 29
 		recordType: record,
24 30
 		payloadLen: payloadLen,
25 31
 	})
32
+	s.mu.Unlock()
26 33
 }
27 34
 
28 35
 // MarkWrite records the current data length as the handshake boundary.
29 36
 func (s *ScoutConnCollected) MarkWrite() {
37
+	s.mu.Lock()
30 38
 	if s.writeIndex < 0 {
31 39
 		s.writeIndex = len(s.data)
32 40
 	}
41
+	s.mu.Unlock()
42
+}
43
+
44
+// Snapshot returns a copy of the collected data and the write index.
45
+func (s *ScoutConnCollected) Snapshot() ([]ScoutConnResult, int) {
46
+	s.mu.Lock()
47
+	snapshot := slices.Clone(s.data)
48
+	writeIndex := s.writeIndex
49
+	s.mu.Unlock()
50
+
51
+	return snapshot, writeIndex
33 52
 }
34 53
 
35 54
 func NewScoutConnCollected() *ScoutConnCollected {

+ 48
- 4
mtglib/internal/doppel/scout_conn_collected_test.go Dosyayı Görüntüle

@@ -1,6 +1,7 @@
1 1
 package doppel
2 2
 
3 3
 import (
4
+	"sync"
4 5
 	"testing"
5 6
 	"time"
6 7
 
@@ -16,8 +17,10 @@ func (suite *ScoutConnCollectedTestSuite) TestAddSingle() {
16 17
 	collected := NewScoutConnCollected()
17 18
 	collected.Add(tls.TypeApplicationData, 100)
18 19
 
19
-	suite.Len(collected.data, 1)
20
-	suite.Equal(byte(tls.TypeApplicationData), collected.data[0].recordType)
20
+	data, _ := collected.Snapshot()
21
+
22
+	suite.Len(data, 1)
23
+	suite.Equal(byte(tls.TypeApplicationData), data[0].recordType)
21 24
 }
22 25
 
23 26
 func (suite *ScoutConnCollectedTestSuite) TestAddTimestampsAreMonotonic() {
@@ -31,11 +34,52 @@ func (suite *ScoutConnCollectedTestSuite) TestAddTimestampsAreMonotonic() {
31 34
 	time.Sleep(time.Microsecond)
32 35
 	collected.Add(tls.TypeApplicationData, 100)
33 36
 
34
-	for i := 1; i < len(collected.data); i++ {
35
-		suite.True(collected.data[i].timestamp.After(collected.data[i-1].timestamp))
37
+	data, _ := collected.Snapshot()
38
+
39
+	for i := 1; i < len(data); i++ {
40
+		suite.True(data[i].timestamp.After(data[i-1].timestamp))
36 41
 	}
37 42
 }
38 43
 
44
+func (suite *ScoutConnCollectedTestSuite) TestConcurrentAddSnapshot() {
45
+	collected := NewScoutConnCollected()
46
+
47
+	var wg sync.WaitGroup
48
+
49
+	wg.Add(3)
50
+
51
+	go func() {
52
+		defer wg.Done()
53
+
54
+		for i := 0; i < 1000; i++ {
55
+			collected.Add(tls.TypeApplicationData, i)
56
+		}
57
+	}()
58
+
59
+	go func() {
60
+		defer wg.Done()
61
+
62
+		for i := 0; i < 100; i++ {
63
+			collected.MarkWrite()
64
+		}
65
+	}()
66
+
67
+	go func() {
68
+		defer wg.Done()
69
+
70
+		for i := 0; i < 1000; i++ {
71
+			// call Snapshot concurrently to exercise the lock under -race
72
+			collected.Snapshot() //nolint:errcheck
73
+		}
74
+	}()
75
+
76
+	wg.Wait()
77
+
78
+	data, writeIndex := collected.Snapshot()
79
+	suite.Len(data, 1000)
80
+	suite.GreaterOrEqual(writeIndex, 0)
81
+}
82
+
39 83
 func TestScoutConnCollected(t *testing.T) {
40 84
 	t.Parallel()
41 85
 	suite.Run(t, &ScoutConnCollectedTestSuite{})

+ 7
- 69
mtglib/internal/tls/fake/client_side.go Dosyayı Görüntüle

@@ -11,8 +11,6 @@ import (
11 11
 	"net"
12 12
 	"slices"
13 13
 	"time"
14
-
15
-	"github.com/9seconds/mtg/v2/mtglib/internal/tls"
16 14
 )
17 15
 
18 16
 const (
@@ -56,25 +54,17 @@ func ReadClientHello(
56 54
 	//  4. New digest should be all 0 except of last 4 bytes
57 55
 	//  5. Last 4 bytes are little endian uint32 of UNIX timestamp when
58 56
 	//     this message was created.
59
-	handshakeCopyBuf := &bytes.Buffer{}
60
-	reader := io.TeeReader(conn, handshakeCopyBuf)
61
-
62
-	reader, err := parseTLSHeader(reader)
63
-	if err != nil {
64
-		return nil, fmt.Errorf("cannot parse tls header: %w", err)
65
-	}
66
-
67
-	reader, err = parseHandshakeHeader(reader)
57
+	clientHelloCopy, handshakeReader, err := parseClientHello(conn)
68 58
 	if err != nil {
69
-		return nil, fmt.Errorf("cannot parse handshake header: %w", err)
59
+		return nil, fmt.Errorf("cannot read client hello: %w", err)
70 60
 	}
71 61
 
72
-	hello, err := parseHandshake(reader)
62
+	hello, err := parseHandshake(handshakeReader)
73 63
 	if err != nil {
74 64
 		return nil, fmt.Errorf("cannot parse handshake: %w", err)
75 65
 	}
76 66
 
77
-	sniHostnames, err := parseSNI(reader)
67
+	sniHostnames, err := parseSNI(handshakeReader)
78 68
 	if err != nil {
79 69
 		return nil, fmt.Errorf("cannot parse SNI: %w", err)
80 70
 	}
@@ -85,10 +75,10 @@ func ReadClientHello(
85 75
 
86 76
 	digest := hmac.New(sha256.New, secret)
87 77
 	// we write a copy of the handshake with client random all nullified.
88
-	digest.Write(handshakeCopyBuf.Next(RandomOffset))
89
-	handshakeCopyBuf.Next(RandomLen)
78
+	digest.Write(clientHelloCopy.Next(RandomOffset))
79
+	clientHelloCopy.Next(RandomLen)
90 80
 	digest.Write(emptyRandom[:])
91
-	digest.Write(handshakeCopyBuf.Bytes())
81
+	digest.Write(clientHelloCopy.Bytes())
92 82
 
93 83
 	computed := digest.Sum(nil)
94 84
 
@@ -110,58 +100,6 @@ func ReadClientHello(
110 100
 	return hello, nil
111 101
 }
112 102
 
113
-func parseTLSHeader(r io.Reader) (io.Reader, error) {
114
-	// record_type(1) + version(2) + size(2)
115
-	//   16 - type is 0x16 (handshake record)
116
-	//   03 01 - protocol version is "3,1" (also known as TLS 1.0)
117
-	//   00 f8 - 0xF8 (248) bytes of handshake message follows
118
-	header := [1 + 2 + 2]byte{}
119
-
120
-	if _, err := io.ReadFull(r, header[:]); err != nil {
121
-		return nil, fmt.Errorf("cannot read record header: %w", err)
122
-	}
123
-
124
-	if header[0] != tls.TypeHandshake {
125
-		return nil, fmt.Errorf("unexpected record type %#x", header[0])
126
-	}
127
-
128
-	if header[1] != 3 || header[2] != 1 {
129
-		return nil, fmt.Errorf("unexpected protocol version %#x %#x", header[1], header[2])
130
-	}
131
-
132
-	length := int64(binary.BigEndian.Uint16(header[3:]))
133
-	buf := &bytes.Buffer{}
134
-
135
-	_, err := io.CopyN(buf, r, length)
136
-
137
-	return buf, err
138
-}
139
-
140
-func parseHandshakeHeader(r io.Reader) (io.Reader, error) {
141
-	// type(1) + size(3 / uint24)
142
-	// 01 - handshake message type 0x01 (client hello)
143
-	// 00 00 f4 - 0xF4 (244) bytes of client hello data follows
144
-	header := [1 + 3]byte{}
145
-
146
-	if _, err := io.ReadFull(r, header[:]); err != nil {
147
-		return nil, fmt.Errorf("cannot read handshake header: %w", err)
148
-	}
149
-
150
-	if header[0] != TypeHandshakeClient {
151
-		return nil, fmt.Errorf("incorrect handshake type: %#x", header[0])
152
-	}
153
-
154
-	// unfortunately there is not uint24 in golang, so we just reust header
155
-	header[0] = 0
156
-
157
-	length := int64(binary.BigEndian.Uint32(header[:]))
158
-	buf := &bytes.Buffer{}
159
-
160
-	_, err := io.CopyN(buf, r, length)
161
-
162
-	return buf, err
163
-}
164
-
165 103
 func parseHandshake(r io.Reader) (*ClientHello, error) {
166 104
 	//  A protocol version of "3,3" (meaning TLS 1.2) is given.
167 105
 	header := [2]byte{}

+ 272
- 0
mtglib/internal/tls/fake/client_side_test.go Dosyayı Görüntüle

@@ -3,8 +3,10 @@ package fake_test
3 3
 import (
4 4
 	"bytes"
5 5
 	"encoding/binary"
6
+	"encoding/json"
6 7
 	"errors"
7 8
 	"io"
9
+	"os"
8 10
 	"testing"
9 11
 	"time"
10 12
 
@@ -393,3 +395,273 @@ func TestParseClientHelloSNI(t *testing.T) {
393 395
 	t.Parallel()
394 396
 	suite.Run(t, &ParseClientHelloSNITestSuite{})
395 397
 }
398
+
399
+// fragmentTLSRecord splits a single TLS record into n TLS records by
400
+// dividing the payload into roughly equal parts. Each part gets its own
401
+// TLS record header with the same record type and version.
402
+func fragmentTLSRecord(t testing.TB, full []byte, n int) []byte {
403
+	t.Helper()
404
+
405
+	recordType := full[0]
406
+	version := full[1:3]
407
+	payload := full[tls.SizeHeader:]
408
+
409
+	chunkSize := len(payload) / n
410
+	result := &bytes.Buffer{}
411
+
412
+	for i := 0; i < n; i++ {
413
+		start := i * chunkSize
414
+		end := start + chunkSize
415
+
416
+		if i == n-1 {
417
+			end = len(payload)
418
+		}
419
+
420
+		chunk := payload[start:end]
421
+		result.WriteByte(recordType)
422
+		result.Write(version)
423
+		require.NoError(t, binary.Write(result, binary.BigEndian, uint16(len(chunk))))
424
+		result.Write(chunk)
425
+	}
426
+
427
+	return result.Bytes()
428
+}
429
+
430
+// splitPayloadAt creates two TLS records from a single record by splitting
431
+// the payload at the given byte position.
432
+func splitPayloadAt(t testing.TB, full []byte, pos int) []byte {
433
+	t.Helper()
434
+
435
+	payload := full[tls.SizeHeader:]
436
+	buf := &bytes.Buffer{}
437
+
438
+	buf.WriteByte(tls.TypeHandshake)
439
+	buf.Write(full[1:3])
440
+	require.NoError(t, binary.Write(buf, binary.BigEndian, uint16(pos)))
441
+	buf.Write(payload[:pos])
442
+
443
+	buf.WriteByte(tls.TypeHandshake)
444
+	buf.Write(full[1:3])
445
+	require.NoError(t, binary.Write(buf, binary.BigEndian, uint16(len(payload)-pos)))
446
+	buf.Write(payload[pos:])
447
+
448
+	return buf.Bytes()
449
+}
450
+
451
+type ParseClientHelloFragmentedTestSuite struct {
452
+	suite.Suite
453
+
454
+	secret   mtglib.Secret
455
+	snapshot *clientHelloSnapshot
456
+}
457
+
458
+func (s *ParseClientHelloFragmentedTestSuite) SetupSuite() {
459
+	parsed, err := mtglib.ParseSecret(
460
+		"ee367a189aee18fa31c190054efd4a8e9573746f726167652e676f6f676c65617069732e636f6d",
461
+	)
462
+	require.NoError(s.T(), err)
463
+
464
+	s.secret = parsed
465
+
466
+	fileData, err := os.ReadFile("testdata/client-hello-ok-19dfe38384b9884b.json")
467
+	require.NoError(s.T(), err)
468
+
469
+	s.snapshot = &clientHelloSnapshot{}
470
+	require.NoError(s.T(), json.Unmarshal(fileData, s.snapshot))
471
+}
472
+
473
+func (s *ParseClientHelloFragmentedTestSuite) makeConn(data []byte) *parseClientHelloConnMock {
474
+	readBuf := &bytes.Buffer{}
475
+	readBuf.Write(data)
476
+
477
+	connMock := &parseClientHelloConnMock{
478
+		readBuf: readBuf,
479
+	}
480
+
481
+	connMock.
482
+		On("SetReadDeadline", mock.AnythingOfType("time.Time")).
483
+		Twice().
484
+		Return(nil)
485
+
486
+	return connMock
487
+}
488
+
489
+func (s *ParseClientHelloFragmentedTestSuite) TestReassemblySuccess() {
490
+	full := s.snapshot.GetFull()
491
+
492
+	tests := []struct {
493
+		name string
494
+		data []byte
495
+	}{
496
+		{"two equal fragments", fragmentTLSRecord(s.T(), full, 2)},
497
+		{"three equal fragments", fragmentTLSRecord(s.T(), full, 3)},
498
+		{"single byte first fragment", splitPayloadAt(s.T(), full, 1)},
499
+		{"three byte first fragment", splitPayloadAt(s.T(), full, 3)},
500
+	}
501
+
502
+	for _, tt := range tests {
503
+		s.Run(tt.name, func() {
504
+			connMock := s.makeConn(tt.data)
505
+			defer connMock.AssertExpectations(s.T())
506
+
507
+			hello, err := fake.ReadClientHello(
508
+				connMock,
509
+				s.secret.Key[:],
510
+				s.secret.Host,
511
+				TolerateTime,
512
+			)
513
+			s.Require().NoError(err)
514
+
515
+			s.Equal(s.snapshot.GetRandom(), hello.Random[:])
516
+			s.Equal(s.snapshot.GetSessionID(), hello.SessionID)
517
+			s.Equal(uint16(s.snapshot.CipherSuite), hello.CipherSuite)
518
+		})
519
+	}
520
+}
521
+
522
+func (s *ParseClientHelloFragmentedTestSuite) TestReassemblyErrors() {
523
+	full := s.snapshot.GetFull()
524
+	payload := full[tls.SizeHeader:]
525
+
526
+	tests := []struct {
527
+		name      string
528
+		buildData func() []byte
529
+		errMsg    string
530
+	}{
531
+		{
532
+			name: "wrong continuation record type",
533
+			buildData: func() []byte {
534
+				buf := &bytes.Buffer{}
535
+				buf.WriteByte(tls.TypeHandshake)
536
+				buf.Write(full[1:3])
537
+				require.NoError(s.T(), binary.Write(buf, binary.BigEndian, uint16(10)))
538
+				buf.Write(payload[:10])
539
+				// Wrong type: application data instead of handshake
540
+				buf.WriteByte(tls.TypeApplicationData)
541
+				buf.Write(full[1:3])
542
+				require.NoError(s.T(), binary.Write(buf, binary.BigEndian, uint16(len(payload)-10)))
543
+				buf.Write(payload[10:])
544
+				return buf.Bytes()
545
+			},
546
+			errMsg: "unexpected record type",
547
+		},
548
+		{
549
+			name: "too many continuation records",
550
+			buildData: func() []byte {
551
+				// Handshake header claiming 256 bytes, but we only send 1 byte per continuation
552
+				handshakePayload := []byte{0x01, 0x00, 0x01, 0x00}
553
+				buf := &bytes.Buffer{}
554
+				buf.WriteByte(tls.TypeHandshake)
555
+				buf.Write([]byte{3, 1})
556
+				require.NoError(s.T(), binary.Write(buf, binary.BigEndian, uint16(len(handshakePayload))))
557
+				buf.Write(handshakePayload)
558
+				for range 11 {
559
+					buf.WriteByte(tls.TypeHandshake)
560
+					buf.Write([]byte{3, 1})
561
+					require.NoError(s.T(), binary.Write(buf, binary.BigEndian, uint16(1)))
562
+					buf.WriteByte(0xAB)
563
+				}
564
+				return buf.Bytes()
565
+			},
566
+			errMsg: "too many fragments",
567
+		},
568
+		{
569
+			name: "zero-length continuation record",
570
+			buildData: func() []byte {
571
+				buf := &bytes.Buffer{}
572
+				buf.WriteByte(tls.TypeHandshake)
573
+				buf.Write(full[1:3])
574
+				require.NoError(s.T(), binary.Write(buf, binary.BigEndian, uint16(10)))
575
+				buf.Write(payload[:10])
576
+				// Valid header but zero-length payload
577
+				buf.WriteByte(tls.TypeHandshake)
578
+				buf.Write(full[1:3])
579
+				require.NoError(s.T(), binary.Write(buf, binary.BigEndian, uint16(0)))
580
+				return buf.Bytes()
581
+			},
582
+			errMsg: "cannot read record header",
583
+		},
584
+		{
585
+			name: "wrong continuation record version",
586
+			buildData: func() []byte {
587
+				buf := &bytes.Buffer{}
588
+				buf.WriteByte(tls.TypeHandshake)
589
+				buf.Write(full[1:3])
590
+				require.NoError(s.T(), binary.Write(buf, binary.BigEndian, uint16(10)))
591
+				buf.Write(payload[:10])
592
+				// Wrong version: 3.3 instead of 3.1
593
+				buf.WriteByte(tls.TypeHandshake)
594
+				buf.Write([]byte{3, 3})
595
+				require.NoError(s.T(), binary.Write(buf, binary.BigEndian, uint16(len(payload)-10)))
596
+				buf.Write(payload[10:])
597
+				return buf.Bytes()
598
+			},
599
+			errMsg: "unexpected protocol version",
600
+		},
601
+		{
602
+			name: "handshake message too large",
603
+			buildData: func() []byte {
604
+				// Handshake header claiming 0x010000 (65536) bytes — exceeds 0xFFFF limit
605
+				handshakePayload := []byte{0x01, 0x01, 0x00, 0x00}
606
+				buf := &bytes.Buffer{}
607
+				buf.WriteByte(tls.TypeHandshake)
608
+				buf.Write([]byte{3, 1})
609
+				require.NoError(s.T(), binary.Write(buf, binary.BigEndian, uint16(len(handshakePayload))))
610
+				buf.Write(handshakePayload)
611
+				return buf.Bytes()
612
+			},
613
+			errMsg: "cannot read record header",
614
+		},
615
+		{
616
+			name: "truncated continuation record header",
617
+			buildData: func() []byte {
618
+				buf := &bytes.Buffer{}
619
+				buf.WriteByte(tls.TypeHandshake)
620
+				buf.Write(full[1:3])
621
+				require.NoError(s.T(), binary.Write(buf, binary.BigEndian, uint16(10)))
622
+				buf.Write(payload[:10])
623
+				// Connection ends mid-header (only 2 bytes)
624
+				buf.WriteByte(tls.TypeHandshake)
625
+				buf.WriteByte(3)
626
+				return buf.Bytes()
627
+			},
628
+			errMsg: "cannot read record header",
629
+		},
630
+		{
631
+			name: "truncated continuation record payload",
632
+			buildData: func() []byte {
633
+				buf := &bytes.Buffer{}
634
+				buf.WriteByte(tls.TypeHandshake)
635
+				buf.Write(full[1:3])
636
+				require.NoError(s.T(), binary.Write(buf, binary.BigEndian, uint16(10)))
637
+				buf.Write(payload[:10])
638
+				// Claims 100 bytes but no payload follows
639
+				buf.WriteByte(tls.TypeHandshake)
640
+				buf.Write(full[1:3])
641
+				require.NoError(s.T(), binary.Write(buf, binary.BigEndian, uint16(100)))
642
+				return buf.Bytes()
643
+			},
644
+			errMsg: "EOF",
645
+		},
646
+	}
647
+
648
+	for _, tt := range tests {
649
+		s.Run(tt.name, func() {
650
+			connMock := s.makeConn(tt.buildData())
651
+			defer connMock.AssertExpectations(s.T())
652
+
653
+			_, err := fake.ReadClientHello(
654
+				connMock,
655
+				s.secret.Key[:],
656
+				s.secret.Host,
657
+				TolerateTime,
658
+			)
659
+			s.ErrorContains(err, tt.errMsg)
660
+		})
661
+	}
662
+}
663
+
664
+func TestParseClientHelloFragmented(t *testing.T) {
665
+	t.Parallel()
666
+	suite.Run(t, &ParseClientHelloFragmentedTestSuite{})
667
+}

+ 158
- 0
mtglib/internal/tls/fake/utils.go Dosyayı Görüntüle

@@ -0,0 +1,158 @@
1
+package fake
2
+
3
+import (
4
+	"bytes"
5
+	"encoding/binary"
6
+	"errors"
7
+	"fmt"
8
+	"io"
9
+
10
+	"github.com/9seconds/mtg/v2/mtglib/internal/tls"
11
+)
12
+
13
+const (
14
+	maxFragmentsCount = 10
15
+)
16
+
17
+var ErrTooManyFragments = errors.New("too many fragments")
18
+
19
+// https://datatracker.ietf.org/doc/html/rfc5246#section-6.2.1
20
+// client hello can be fragmented in a series of packets:
21
+//
22
+//	Bytes on the wire:
23
+//
24
+// 16 03 01 00 F8 01 00 00 F4 03 03 [32 bytes random] [session_id] [ciphers] [SNI...]
25
+// ├─────────────┤├──────────────────────────────────────────────────────────────────┤
26
+//
27
+//	TLS record       Payload (248 bytes)
28
+//	header (5B)
29
+//
30
+//	16    = Handshake
31
+//	03 01 = TLS 1.0 (record layer version)
32
+//	00 F8 = 248 bytes follow
33
+//
34
+//	01       = ClientHello (handshake type)
35
+//	00 00 F4 = 244 bytes of handshake body
36
+//	03 03    = TLS 1.2 (actual protocol version)
37
+//	...rest of ClientHello...
38
+//
39
+// Fragmented record look like:
40
+//
41
+//	Record 1:
42
+//
43
+// 16 03 01 00 03 01 00 00
44
+// ├─────────────┤├──────┤
45
+//
46
+//	TLS header    3 bytes of payload
47
+//
48
+//	16    = Handshake
49
+//	03 01 = TLS 1.0
50
+//	00 03 = only 3 bytes follow
51
+//
52
+//	01       = ClientHello type
53
+//	00 00    = first 2 bytes of the uint24 length (INCOMPLETE!)
54
+//
55
+// Record 2:
56
+// 16 03 01 00 F5 F4 03 03 [32 bytes random] [session_id] [ciphers] [SNI...]
57
+// ├─────────────┤├────────────────────────────────────────────────────────────┤
58
+//
59
+//	TLS header    remaining 245 bytes of payload
60
+//
61
+//	16    = Handshake
62
+//	03 01 = TLS 1.0
63
+//	00 F5 = 245 bytes follow
64
+//
65
+//	F4    = last byte of uint24 length (now complete: 00 00 F4 = 244)
66
+//	03 03 = TLS 1.2
67
+//	...rest of ClientHello continues...
68
+//
69
+// So it means that there could be a series of handshake packets of different
70
+// lengths. The goal of this function is to concatenate these fragments.
71
+type fragmentedHandshakeReader struct {
72
+	r             io.Reader
73
+	buf           bytes.Buffer
74
+	readFragments int
75
+}
76
+
77
+func (f *fragmentedHandshakeReader) Read(p []byte) (int, error) {
78
+	if n, err := f.buf.Read(p); err == nil {
79
+		return n, nil
80
+	}
81
+
82
+	f.buf.Reset()
83
+
84
+	for f.buf.Len() == 0 {
85
+		if f.readFragments > maxFragmentsCount {
86
+			return 0, ErrTooManyFragments
87
+		}
88
+
89
+		if err := f.parseNextFragment(); err != nil {
90
+			return 0, err
91
+		}
92
+
93
+		f.readFragments++
94
+	}
95
+
96
+	return f.buf.Read(p)
97
+}
98
+
99
+func (f *fragmentedHandshakeReader) parseNextFragment() error {
100
+	// record_type(1) + version(2) + size(2)
101
+	//   16 - type is 0x16 (handshake record)
102
+	//   03 01 - protocol version is "3,1" (also known as TLS 1.0)
103
+	//   00 f8 - 0xF8 (248) bytes of handshake message follows
104
+	header := [1 + 2 + 2]byte{}
105
+
106
+	if _, err := io.ReadFull(f.r, header[:]); err != nil {
107
+		return fmt.Errorf("cannot read record header: %w", err)
108
+	}
109
+
110
+	if header[0] != tls.TypeHandshake {
111
+		return fmt.Errorf("unexpected record type %#x", header[0])
112
+	}
113
+
114
+	if header[1] != 3 || header[2] != 1 {
115
+		return fmt.Errorf("unexpected protocol version %#x %#x", header[1], header[2])
116
+	}
117
+
118
+	length := int64(binary.BigEndian.Uint16(header[3:]))
119
+	_, err := io.CopyN(&f.buf, f.r, length)
120
+
121
+	return err
122
+}
123
+
124
+func parseClientHello(r io.Reader) (*bytes.Buffer, *bytes.Buffer, error) {
125
+	r = &fragmentedHandshakeReader{r: r}
126
+	header := [1 + 3]byte{}
127
+
128
+	if _, err := io.ReadFull(r, header[:]); err != nil {
129
+		return nil, nil, fmt.Errorf("cannot read handshake header: %w", err)
130
+	}
131
+
132
+	if header[0] != TypeHandshakeClient {
133
+		return nil, nil, fmt.Errorf("incorrect handshake type: %#x", header[0])
134
+	}
135
+
136
+	// unfortunately there is not uint24 in golang, so we just reuse header
137
+	header[0] = 0
138
+	length := int64(binary.BigEndian.Uint32(header[:]))
139
+
140
+	clientHelloCopy := &bytes.Buffer{}
141
+	clientHelloCopy.Write([]byte{tls.TypeHandshake, 3, 1})
142
+	binary.Write( //nolint: errcheck
143
+		clientHelloCopy,
144
+		binary.BigEndian,
145
+		// 1 for handshake type
146
+		// 3 for handshake length
147
+		uint16(1+3+length),
148
+	)
149
+	clientHelloCopy.WriteByte(TypeHandshakeClient)
150
+	clientHelloCopy.Write(header[1:])
151
+
152
+	handshakeCopy := &bytes.Buffer{}
153
+	writer := io.MultiWriter(clientHelloCopy, handshakeCopy)
154
+
155
+	_, err := io.CopyN(writer, r, length)
156
+
157
+	return clientHelloCopy, handshakeCopy, err
158
+}

+ 0
- 1
mtglib/proxy_opts.go Dosyayı Görüntüle

@@ -160,7 +160,6 @@ type ProxyOpts struct {
160 160
 
161 161
 	// DoppelGangerDRS defines if TLS Dynamic Record Sizing is active.
162 162
 	DoppelGangerDRS bool
163
-
164 163
 }
165 164
 
166 165
 func (p ProxyOpts) valid() error {

+ 1
- 1
mtglib/proxy_test.go Dosyayı Görüntüle

@@ -175,7 +175,7 @@ func (suite *ProxyTestSuite) TestHTTPSRequest() {
175 175
 	addr := fmt.Sprintf("https://%s/headers", suite.ProxyAddress())
176 176
 
177 177
 	resp, err := client.Get(addr) //nolint: noctx
178
-	suite.NoError(err)
178
+	suite.Require().NoError(err)
179 179
 
180 180
 	defer resp.Body.Close() //nolint: errcheck
181 181
 

Loading…
İptal
Kaydet