Просмотр исходного кода

Merge pull request #433 from 9seconds/refactor-tls-fragmentation

Refactor TLS fragmenting
tags/v2.2.7^2^2
Sergei Arkhipov 1 месяц назад
Родитель
Сommit
dbced77566
Аккаунт пользователя с таким Email не найден

+ 7
- 175
mtglib/internal/tls/fake/client_side.go Просмотреть файл

6
 	"crypto/sha256"
6
 	"crypto/sha256"
7
 	"crypto/subtle"
7
 	"crypto/subtle"
8
 	"encoding/binary"
8
 	"encoding/binary"
9
-	"errors"
10
 	"fmt"
9
 	"fmt"
11
 	"io"
10
 	"io"
12
 	"net"
11
 	"net"
13
 	"slices"
12
 	"slices"
14
 	"time"
13
 	"time"
15
-
16
-	"github.com/9seconds/mtg/v2/mtglib/internal/tls"
17
 )
14
 )
18
 
15
 
19
 const (
16
 const (
24
 	RandomOffset = 1 + 2 + 2 + 1 + 3 + 2
21
 	RandomOffset = 1 + 2 + 2 + 1 + 3 + 2
25
 
22
 
26
 	sniDNSNamesListType = 0
23
 	sniDNSNamesListType = 0
27
-
28
-	// maxContinuationRecords limits the number of continuation TLS records
29
-	// that reassembleTLSHandshake will read. This prevents resource exhaustion
30
-	// from adversarial fragmentation.
31
-	maxContinuationRecords = 10
32
 )
24
 )
33
 
25
 
34
 var (
26
 var (
62
 	//  4. New digest should be all 0 except of last 4 bytes
54
 	//  4. New digest should be all 0 except of last 4 bytes
63
 	//  5. Last 4 bytes are little endian uint32 of UNIX timestamp when
55
 	//  5. Last 4 bytes are little endian uint32 of UNIX timestamp when
64
 	//     this message was created.
56
 	//     this message was created.
65
-	reassembled, err := reassembleTLSHandshake(conn)
57
+	clientHelloCopy, handshakeReader, err := parseClientHello(conn)
66
 	if err != nil {
58
 	if err != nil {
67
-		return nil, fmt.Errorf("cannot reassemble TLS records: %w", err)
68
-	}
69
-
70
-	handshakeCopyBuf := &bytes.Buffer{}
71
-	reader := io.TeeReader(reassembled, handshakeCopyBuf)
72
-
73
-	// Skip the TLS record header (validated during reassembly).
74
-	// The header still flows through TeeReader into handshakeCopyBuf for HMAC.
75
-	if _, err = io.CopyN(io.Discard, reader, tls.SizeHeader); err != nil {
76
-		return nil, fmt.Errorf("cannot skip tls header: %w", err)
59
+		return nil, fmt.Errorf("cannot read client hello: %w", err)
77
 	}
60
 	}
78
 
61
 
79
-	reader, err = parseHandshakeHeader(reader)
80
-	if err != nil {
81
-		return nil, fmt.Errorf("cannot parse handshake header: %w", err)
82
-	}
83
-
84
-	hello, err := parseHandshake(reader)
62
+	hello, err := parseHandshake(handshakeReader)
85
 	if err != nil {
63
 	if err != nil {
86
 		return nil, fmt.Errorf("cannot parse handshake: %w", err)
64
 		return nil, fmt.Errorf("cannot parse handshake: %w", err)
87
 	}
65
 	}
88
 
66
 
89
-	sniHostnames, err := parseSNI(reader)
67
+	sniHostnames, err := parseSNI(handshakeReader)
90
 	if err != nil {
68
 	if err != nil {
91
 		return nil, fmt.Errorf("cannot parse SNI: %w", err)
69
 		return nil, fmt.Errorf("cannot parse SNI: %w", err)
92
 	}
70
 	}
97
 
75
 
98
 	digest := hmac.New(sha256.New, secret)
76
 	digest := hmac.New(sha256.New, secret)
99
 	// we write a copy of the handshake with client random all nullified.
77
 	// we write a copy of the handshake with client random all nullified.
100
-	digest.Write(handshakeCopyBuf.Next(RandomOffset))
101
-	handshakeCopyBuf.Next(RandomLen)
78
+	digest.Write(clientHelloCopy.Next(RandomOffset))
79
+	clientHelloCopy.Next(RandomLen)
102
 	digest.Write(emptyRandom[:])
80
 	digest.Write(emptyRandom[:])
103
-	digest.Write(handshakeCopyBuf.Bytes())
81
+	digest.Write(clientHelloCopy.Bytes())
104
 
82
 
105
 	computed := digest.Sum(nil)
83
 	computed := digest.Sum(nil)
106
 
84
 
122
 	return hello, nil
100
 	return hello, nil
123
 }
101
 }
124
 
102
 
125
-// reassembleTLSHandshake reads one or more TLS records from conn,
126
-// validates the record type and version, and reassembles fragmented
127
-// handshake payloads into a single TLS record.
128
-//
129
-// Per RFC 5246 Section 6.2.1, handshake messages may be fragmented
130
-// across multiple TLS records. DPI bypass tools like ByeDPI use this
131
-// to evade censorship.
132
-//
133
-// The returned buffer contains the full TLS record (header + payload)
134
-// so that callers can include the header in HMAC computation.
135
-func reassembleTLSHandshake(conn io.Reader) (*bytes.Buffer, error) {
136
-	header := [tls.SizeHeader]byte{}
137
-
138
-	if _, err := io.ReadFull(conn, header[:]); err != nil {
139
-		return nil, fmt.Errorf("cannot read record header: %w", err)
140
-	}
141
-
142
-	length := int64(binary.BigEndian.Uint16(header[3:]))
143
-	payload := &bytes.Buffer{}
144
-
145
-	if _, err := io.CopyN(payload, conn, length); err != nil {
146
-		return nil, fmt.Errorf("cannot read record payload: %w", err)
147
-	}
148
-
149
-	if header[0] != tls.TypeHandshake {
150
-		return nil, fmt.Errorf("unexpected record type %#x", header[0])
151
-	}
152
-
153
-	if header[1] != 3 || header[2] != 1 {
154
-		return nil, fmt.Errorf("unexpected protocol version %#x %#x", header[1], header[2])
155
-	}
156
-
157
-	// Reassemble fragmented payload. continuationCount caps the total
158
-	// number of continuation records across both phases below.
159
-	continuationCount := 0
160
-
161
-	// Phase 1: read continuation records until we have at least the
162
-	// 4-byte handshake header (type + uint24 length) to determine the
163
-	// expected total size.
164
-	for ; payload.Len() < 4 && continuationCount < maxContinuationRecords; continuationCount++ {
165
-		prevLen := payload.Len()
166
-
167
-		if err := readContinuationRecord(conn, payload); err != nil {
168
-			payload.Truncate(prevLen) // discard partial data on error
169
-
170
-			if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) {
171
-				break // no more records — let downstream parsing handle what we have
172
-			}
173
-
174
-			return nil, err
175
-		}
176
-	}
177
-
178
-	// Phase 2: we know the expected handshake size — read remaining
179
-	// continuation records until the payload is complete.
180
-	if payload.Len() >= 4 {
181
-		p := payload.Bytes()
182
-		expectedTotal := 4 + (int(p[1])<<16 | int(p[2])<<8 | int(p[3]))
183
-
184
-		if expectedTotal > 0xFFFF {
185
-			return nil, fmt.Errorf("handshake message too large: %d bytes", expectedTotal)
186
-		}
187
-
188
-		for ; payload.Len() < expectedTotal && continuationCount < maxContinuationRecords; continuationCount++ {
189
-			if err := readContinuationRecord(conn, payload); err != nil {
190
-				return nil, err
191
-			}
192
-		}
193
-
194
-		if payload.Len() < expectedTotal {
195
-			return nil, fmt.Errorf("cannot reassemble handshake: too many continuation records")
196
-		}
197
-
198
-		payload.Truncate(expectedTotal)
199
-	}
200
-
201
-	if payload.Len() > 0xFFFF {
202
-		return nil, fmt.Errorf("reassembled payload too large: %d bytes", payload.Len())
203
-	}
204
-
205
-	// Reconstruct a single TLS record with the reassembled payload.
206
-	result := &bytes.Buffer{}
207
-	result.Grow(tls.SizeHeader + payload.Len())
208
-	result.Write(header[:3])
209
-	binary.Write(result, binary.BigEndian, uint16(payload.Len())) //nolint:errcheck // bytes.Buffer.Write never fails
210
-	result.Write(payload.Bytes())
211
-
212
-	return result, nil
213
-}
214
-
215
-// readContinuationRecord reads the next TLS record header and appends its
216
-// full payload to dst. It returns an error if the record is not a handshake
217
-// record.
218
-func readContinuationRecord(conn io.Reader, dst *bytes.Buffer) error {
219
-	nextHeader := [tls.SizeHeader]byte{}
220
-
221
-	if _, err := io.ReadFull(conn, nextHeader[:]); err != nil {
222
-		return fmt.Errorf("cannot read continuation record header: %w", err)
223
-	}
224
-
225
-	if nextHeader[0] != tls.TypeHandshake {
226
-		return fmt.Errorf("unexpected continuation record type %#x", nextHeader[0])
227
-	}
228
-
229
-	if nextHeader[1] != 3 || nextHeader[2] != 1 {
230
-		return fmt.Errorf("unexpected continuation record version %#x %#x", nextHeader[1], nextHeader[2])
231
-	}
232
-
233
-	nextLength := int64(binary.BigEndian.Uint16(nextHeader[3:]))
234
-
235
-	if nextLength == 0 {
236
-		return fmt.Errorf("zero-length continuation record")
237
-	}
238
-
239
-	if _, err := io.CopyN(dst, conn, nextLength); err != nil {
240
-		return fmt.Errorf("cannot read continuation record payload: %w", err)
241
-	}
242
-
243
-	return nil
244
-}
245
-
246
-func parseHandshakeHeader(r io.Reader) (io.Reader, error) {
247
-	// type(1) + size(3 / uint24)
248
-	// 01 - handshake message type 0x01 (client hello)
249
-	// 00 00 f4 - 0xF4 (244) bytes of client hello data follows
250
-	header := [1 + 3]byte{}
251
-
252
-	if _, err := io.ReadFull(r, header[:]); err != nil {
253
-		return nil, fmt.Errorf("cannot read handshake header: %w", err)
254
-	}
255
-
256
-	if header[0] != TypeHandshakeClient {
257
-		return nil, fmt.Errorf("incorrect handshake type: %#x", header[0])
258
-	}
259
-
260
-	// unfortunately there is not uint24 in golang, so we just reust header
261
-	header[0] = 0
262
-
263
-	length := int64(binary.BigEndian.Uint32(header[:]))
264
-	buf := &bytes.Buffer{}
265
-
266
-	_, err := io.CopyN(buf, r, length)
267
-
268
-	return buf, err
269
-}
270
-
271
 func parseHandshake(r io.Reader) (*ClientHello, error) {
103
 func parseHandshake(r io.Reader) (*ClientHello, error) {
272
 	//  A protocol version of "3,3" (meaning TLS 1.2) is given.
104
 	//  A protocol version of "3,3" (meaning TLS 1.2) is given.
273
 	header := [2]byte{}
105
 	header := [2]byte{}

+ 7
- 7
mtglib/internal/tls/fake/client_side_test.go Просмотреть файл

543
 				buf.Write(payload[10:])
543
 				buf.Write(payload[10:])
544
 				return buf.Bytes()
544
 				return buf.Bytes()
545
 			},
545
 			},
546
-			errMsg: "unexpected continuation record type",
546
+			errMsg: "unexpected record type",
547
 		},
547
 		},
548
 		{
548
 		{
549
 			name: "too many continuation records",
549
 			name: "too many continuation records",
563
 				}
563
 				}
564
 				return buf.Bytes()
564
 				return buf.Bytes()
565
 			},
565
 			},
566
-			errMsg: "too many continuation records",
566
+			errMsg: "too many fragments",
567
 		},
567
 		},
568
 		{
568
 		{
569
 			name: "zero-length continuation record",
569
 			name: "zero-length continuation record",
579
 				require.NoError(s.T(), binary.Write(buf, binary.BigEndian, uint16(0)))
579
 				require.NoError(s.T(), binary.Write(buf, binary.BigEndian, uint16(0)))
580
 				return buf.Bytes()
580
 				return buf.Bytes()
581
 			},
581
 			},
582
-			errMsg: "zero-length continuation record",
582
+			errMsg: "cannot read record header",
583
 		},
583
 		},
584
 		{
584
 		{
585
 			name: "wrong continuation record version",
585
 			name: "wrong continuation record version",
596
 				buf.Write(payload[10:])
596
 				buf.Write(payload[10:])
597
 				return buf.Bytes()
597
 				return buf.Bytes()
598
 			},
598
 			},
599
-			errMsg: "unexpected continuation record version",
599
+			errMsg: "unexpected protocol version",
600
 		},
600
 		},
601
 		{
601
 		{
602
 			name: "handshake message too large",
602
 			name: "handshake message too large",
610
 				buf.Write(handshakePayload)
610
 				buf.Write(handshakePayload)
611
 				return buf.Bytes()
611
 				return buf.Bytes()
612
 			},
612
 			},
613
-			errMsg: "handshake message too large",
613
+			errMsg: "cannot read record header",
614
 		},
614
 		},
615
 		{
615
 		{
616
 			name: "truncated continuation record header",
616
 			name: "truncated continuation record header",
625
 				buf.WriteByte(3)
625
 				buf.WriteByte(3)
626
 				return buf.Bytes()
626
 				return buf.Bytes()
627
 			},
627
 			},
628
-			errMsg: "cannot read continuation record header",
628
+			errMsg: "cannot read record header",
629
 		},
629
 		},
630
 		{
630
 		{
631
 			name: "truncated continuation record payload",
631
 			name: "truncated continuation record payload",
641
 				require.NoError(s.T(), binary.Write(buf, binary.BigEndian, uint16(100)))
641
 				require.NoError(s.T(), binary.Write(buf, binary.BigEndian, uint16(100)))
642
 				return buf.Bytes()
642
 				return buf.Bytes()
643
 			},
643
 			},
644
-			errMsg: "cannot read continuation record payload",
644
+			errMsg: "EOF",
645
 		},
645
 		},
646
 	}
646
 	}
647
 
647
 

+ 158
- 0
mtglib/internal/tls/fake/utils.go Просмотреть файл

1
+package fake
2
+
3
+import (
4
+	"bytes"
5
+	"encoding/binary"
6
+	"errors"
7
+	"fmt"
8
+	"io"
9
+
10
+	"github.com/9seconds/mtg/v2/mtglib/internal/tls"
11
+)
12
+
13
+const (
14
+	maxFragmentsCount = 10
15
+)
16
+
17
+var ErrTooManyFragments = errors.New("too many fragments")
18
+
19
+// https://datatracker.ietf.org/doc/html/rfc5246#section-6.2.1
20
+// client hello can be fragmented in a series of packets:
21
+//
22
+//	Bytes on the wire:
23
+//
24
+// 16 03 01 00 F8 01 00 00 F4 03 03 [32 bytes random] [session_id] [ciphers] [SNI...]
25
+// ├─────────────┤├──────────────────────────────────────────────────────────────────┤
26
+//
27
+//	TLS record       Payload (248 bytes)
28
+//	header (5B)
29
+//
30
+//	16    = Handshake
31
+//	03 01 = TLS 1.0 (record layer version)
32
+//	00 F8 = 248 bytes follow
33
+//
34
+//	01       = ClientHello (handshake type)
35
+//	00 00 F4 = 244 bytes of handshake body
36
+//	03 03    = TLS 1.2 (actual protocol version)
37
+//	...rest of ClientHello...
38
+//
39
+// Fragmented record look like:
40
+//
41
+//	Record 1:
42
+//
43
+// 16 03 01 00 03 01 00 00
44
+// ├─────────────┤├──────┤
45
+//
46
+//	TLS header    3 bytes of payload
47
+//
48
+//	16    = Handshake
49
+//	03 01 = TLS 1.0
50
+//	00 03 = only 3 bytes follow
51
+//
52
+//	01       = ClientHello type
53
+//	00 00    = first 2 bytes of the uint24 length (INCOMPLETE!)
54
+//
55
+// Record 2:
56
+// 16 03 01 00 F5 F4 03 03 [32 bytes random] [session_id] [ciphers] [SNI...]
57
+// ├─────────────┤├────────────────────────────────────────────────────────────┤
58
+//
59
+//	TLS header    remaining 245 bytes of payload
60
+//
61
+//	16    = Handshake
62
+//	03 01 = TLS 1.0
63
+//	00 F5 = 245 bytes follow
64
+//
65
+//	F4    = last byte of uint24 length (now complete: 00 00 F4 = 244)
66
+//	03 03 = TLS 1.2
67
+//	...rest of ClientHello continues...
68
+//
69
+// So it means that there could be a series of handshake packets of different
70
+// lengths. The goal of this function is to concatenate these fragments.
71
+type fragmentedHandshakeReader struct {
72
+	r             io.Reader
73
+	buf           bytes.Buffer
74
+	readFragments int
75
+}
76
+
77
+func (f *fragmentedHandshakeReader) Read(p []byte) (int, error) {
78
+	if n, err := f.buf.Read(p); err == nil {
79
+		return n, nil
80
+	}
81
+
82
+	f.buf.Reset()
83
+
84
+	for f.buf.Len() == 0 {
85
+		if f.readFragments > maxFragmentsCount {
86
+			return 0, ErrTooManyFragments
87
+		}
88
+
89
+		if err := f.parseNextFragment(); err != nil {
90
+			return 0, err
91
+		}
92
+
93
+		f.readFragments++
94
+	}
95
+
96
+	return f.buf.Read(p)
97
+}
98
+
99
+func (f *fragmentedHandshakeReader) parseNextFragment() error {
100
+	// record_type(1) + version(2) + size(2)
101
+	//   16 - type is 0x16 (handshake record)
102
+	//   03 01 - protocol version is "3,1" (also known as TLS 1.0)
103
+	//   00 f8 - 0xF8 (248) bytes of handshake message follows
104
+	header := [1 + 2 + 2]byte{}
105
+
106
+	if _, err := io.ReadFull(f.r, header[:]); err != nil {
107
+		return fmt.Errorf("cannot read record header: %w", err)
108
+	}
109
+
110
+	if header[0] != tls.TypeHandshake {
111
+		return fmt.Errorf("unexpected record type %#x", header[0])
112
+	}
113
+
114
+	if header[1] != 3 || header[2] != 1 {
115
+		return fmt.Errorf("unexpected protocol version %#x %#x", header[1], header[2])
116
+	}
117
+
118
+	length := int64(binary.BigEndian.Uint16(header[3:]))
119
+	_, err := io.CopyN(&f.buf, f.r, length)
120
+
121
+	return err
122
+}
123
+
124
+func parseClientHello(r io.Reader) (*bytes.Buffer, *bytes.Buffer, error) {
125
+	r = &fragmentedHandshakeReader{r: r}
126
+	header := [1 + 3]byte{}
127
+
128
+	if _, err := io.ReadFull(r, header[:]); err != nil {
129
+		return nil, nil, fmt.Errorf("cannot read handshake header: %w", err)
130
+	}
131
+
132
+	if header[0] != TypeHandshakeClient {
133
+		return nil, nil, fmt.Errorf("incorrect handshake type: %#x", header[0])
134
+	}
135
+
136
+	// unfortunately there is not uint24 in golang, so we just reuse header
137
+	header[0] = 0
138
+	length := int64(binary.BigEndian.Uint32(header[:]))
139
+
140
+	clientHelloCopy := &bytes.Buffer{}
141
+	clientHelloCopy.Write([]byte{tls.TypeHandshake, 3, 1})
142
+	binary.Write( //nolint: errcheck
143
+		clientHelloCopy,
144
+		binary.BigEndian,
145
+		// 1 for handshake type
146
+		// 3 for handshake length
147
+		uint16(1+3+length),
148
+	)
149
+	clientHelloCopy.WriteByte(TypeHandshakeClient)
150
+	clientHelloCopy.Write(header[1:])
151
+
152
+	handshakeCopy := &bytes.Buffer{}
153
+	writer := io.MultiWriter(clientHelloCopy, handshakeCopy)
154
+
155
+	_, err := io.CopyN(writer, r, length)
156
+
157
+	return clientHelloCopy, handshakeCopy, err
158
+}

Загрузка…
Отмена
Сохранить