Преглед изворни кода

Merge upstream: refactor TLS fragmentation (9seconds/mtg#433)

Upstream replaced our reassembleTLSHandshake with a cleaner
fragmentedHandshakeReader + parseClientHello in utils.go.
Adapted ReadClientHelloMulti to use the new API.
pull/450/head
Alexey Dolotov пре 1 месец
родитељ
комит
33a5cd3b24

+ 6
- 174
mtglib/internal/tls/fake/client_side.go Прегледај датотеку

@@ -6,14 +6,11 @@ import (
6 6
 	"crypto/sha256"
7 7
 	"crypto/subtle"
8 8
 	"encoding/binary"
9
-	"errors"
10 9
 	"fmt"
11 10
 	"io"
12 11
 	"net"
13 12
 	"slices"
14 13
 	"time"
15
-
16
-	"github.com/dolonet/mtg-multi/mtglib/internal/tls"
17 14
 )
18 15
 
19 16
 const (
@@ -24,11 +21,6 @@ const (
24 21
 	RandomOffset = 1 + 2 + 2 + 1 + 3 + 2
25 22
 
26 23
 	sniDNSNamesListType = 0
27
-
28
-	// maxContinuationRecords limits the number of continuation TLS records
29
-	// that reassembleTLSHandshake will read. This prevents resource exhaustion
30
-	// from adversarial fragmentation.
31
-	maxContinuationRecords = 10
32 24
 )
33 25
 
34 26
 var (
@@ -77,31 +69,17 @@ func ReadClientHelloMulti(
77 69
 	}
78 70
 	defer conn.SetReadDeadline(resetDeadline) //nolint: errcheck
79 71
 
80
-	reassembled, err := reassembleTLSHandshake(conn)
72
+	clientHelloCopy, handshakeReader, err := parseClientHello(conn)
81 73
 	if err != nil {
82
-		return nil, fmt.Errorf("cannot reassemble TLS records: %w", err)
83
-	}
84
-
85
-	handshakeCopyBuf := &bytes.Buffer{}
86
-	reader := io.TeeReader(reassembled, handshakeCopyBuf)
87
-
88
-	// Skip the TLS record header (validated during reassembly).
89
-	// The header still flows through TeeReader into handshakeCopyBuf for HMAC.
90
-	if _, err = io.CopyN(io.Discard, reader, tls.SizeHeader); err != nil {
91
-		return nil, fmt.Errorf("cannot skip tls header: %w", err)
74
+		return nil, fmt.Errorf("cannot read client hello: %w", err)
92 75
 	}
93 76
 
94
-	reader, err = parseHandshakeHeader(reader)
95
-	if err != nil {
96
-		return nil, fmt.Errorf("cannot parse handshake header: %w", err)
97
-	}
98
-
99
-	hello, err := parseHandshake(reader)
77
+	hello, err := parseHandshake(handshakeReader)
100 78
 	if err != nil {
101 79
 		return nil, fmt.Errorf("cannot parse handshake: %w", err)
102 80
 	}
103 81
 
104
-	sniHostnames, err := parseSNI(reader)
82
+	sniHostnames, err := parseSNI(handshakeReader)
105 83
 	if err != nil {
106 84
 		return nil, fmt.Errorf("cannot parse SNI: %w", err)
107 85
 	}
@@ -110,8 +88,8 @@ func ReadClientHelloMulti(
110 88
 		return nil, fmt.Errorf("cannot find %s in %v", hostname, sniHostnames)
111 89
 	}
112 90
 
113
-	// Save the handshake bytes so we can reuse them for each secret attempt.
114
-	handshakeBytes := handshakeCopyBuf.Bytes()
91
+	// Save the full client hello bytes so we can replay them for each secret.
92
+	handshakeBytes := clientHelloCopy.Bytes()
115 93
 
116 94
 	for idx, secret := range secrets {
117 95
 		digest := hmac.New(sha256.New, secret)
@@ -147,152 +125,6 @@ func ReadClientHelloMulti(
147 125
 	return nil, ErrBadDigest
148 126
 }
149 127
 
150
-// reassembleTLSHandshake reads one or more TLS records from conn,
151
-// validates the record type and version, and reassembles fragmented
152
-// handshake payloads into a single TLS record.
153
-//
154
-// Per RFC 5246 Section 6.2.1, handshake messages may be fragmented
155
-// across multiple TLS records. DPI bypass tools like ByeDPI use this
156
-// to evade censorship.
157
-//
158
-// The returned buffer contains the full TLS record (header + payload)
159
-// so that callers can include the header in HMAC computation.
160
-func reassembleTLSHandshake(conn io.Reader) (*bytes.Buffer, error) {
161
-	header := [tls.SizeHeader]byte{}
162
-
163
-	if _, err := io.ReadFull(conn, header[:]); err != nil {
164
-		return nil, fmt.Errorf("cannot read record header: %w", err)
165
-	}
166
-
167
-	length := int64(binary.BigEndian.Uint16(header[3:]))
168
-	payload := &bytes.Buffer{}
169
-
170
-	if _, err := io.CopyN(payload, conn, length); err != nil {
171
-		return nil, fmt.Errorf("cannot read record payload: %w", err)
172
-	}
173
-
174
-	if header[0] != tls.TypeHandshake {
175
-		return nil, fmt.Errorf("unexpected record type %#x", header[0])
176
-	}
177
-
178
-	if header[1] != 3 || header[2] != 1 {
179
-		return nil, fmt.Errorf("unexpected protocol version %#x %#x", header[1], header[2])
180
-	}
181
-
182
-	// Reassemble fragmented payload. continuationCount caps the total
183
-	// number of continuation records across both phases below.
184
-	continuationCount := 0
185
-
186
-	// Phase 1: read continuation records until we have at least the
187
-	// 4-byte handshake header (type + uint24 length) to determine the
188
-	// expected total size.
189
-	for ; payload.Len() < 4 && continuationCount < maxContinuationRecords; continuationCount++ {
190
-		prevLen := payload.Len()
191
-
192
-		if err := readContinuationRecord(conn, payload); err != nil {
193
-			payload.Truncate(prevLen) // discard partial data on error
194
-
195
-			if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) {
196
-				break // no more records — let downstream parsing handle what we have
197
-			}
198
-
199
-			return nil, err
200
-		}
201
-	}
202
-
203
-	// Phase 2: we know the expected handshake size — read remaining
204
-	// continuation records until the payload is complete.
205
-	if payload.Len() >= 4 {
206
-		p := payload.Bytes()
207
-		expectedTotal := 4 + (int(p[1])<<16 | int(p[2])<<8 | int(p[3]))
208
-
209
-		if expectedTotal > 0xFFFF {
210
-			return nil, fmt.Errorf("handshake message too large: %d bytes", expectedTotal)
211
-		}
212
-
213
-		for ; payload.Len() < expectedTotal && continuationCount < maxContinuationRecords; continuationCount++ {
214
-			if err := readContinuationRecord(conn, payload); err != nil {
215
-				return nil, err
216
-			}
217
-		}
218
-
219
-		if payload.Len() < expectedTotal {
220
-			return nil, fmt.Errorf("cannot reassemble handshake: too many continuation records")
221
-		}
222
-
223
-		payload.Truncate(expectedTotal)
224
-	}
225
-
226
-	if payload.Len() > 0xFFFF {
227
-		return nil, fmt.Errorf("reassembled payload too large: %d bytes", payload.Len())
228
-	}
229
-
230
-	// Reconstruct a single TLS record with the reassembled payload.
231
-	result := &bytes.Buffer{}
232
-	result.Grow(tls.SizeHeader + payload.Len())
233
-	result.Write(header[:3])
234
-	binary.Write(result, binary.BigEndian, uint16(payload.Len())) //nolint:errcheck // bytes.Buffer.Write never fails
235
-	result.Write(payload.Bytes())
236
-
237
-	return result, nil
238
-}
239
-
240
-// readContinuationRecord reads the next TLS record header and appends its
241
-// full payload to dst. It returns an error if the record is not a handshake
242
-// record.
243
-func readContinuationRecord(conn io.Reader, dst *bytes.Buffer) error {
244
-	nextHeader := [tls.SizeHeader]byte{}
245
-
246
-	if _, err := io.ReadFull(conn, nextHeader[:]); err != nil {
247
-		return fmt.Errorf("cannot read continuation record header: %w", err)
248
-	}
249
-
250
-	if nextHeader[0] != tls.TypeHandshake {
251
-		return fmt.Errorf("unexpected continuation record type %#x", nextHeader[0])
252
-	}
253
-
254
-	if nextHeader[1] != 3 || nextHeader[2] != 1 {
255
-		return fmt.Errorf("unexpected continuation record version %#x %#x", nextHeader[1], nextHeader[2])
256
-	}
257
-
258
-	nextLength := int64(binary.BigEndian.Uint16(nextHeader[3:]))
259
-
260
-	if nextLength == 0 {
261
-		return fmt.Errorf("zero-length continuation record")
262
-	}
263
-
264
-	if _, err := io.CopyN(dst, conn, nextLength); err != nil {
265
-		return fmt.Errorf("cannot read continuation record payload: %w", err)
266
-	}
267
-
268
-	return nil
269
-}
270
-
271
-func parseHandshakeHeader(r io.Reader) (io.Reader, error) {
272
-	// type(1) + size(3 / uint24)
273
-	// 01 - handshake message type 0x01 (client hello)
274
-	// 00 00 f4 - 0xF4 (244) bytes of client hello data follows
275
-	header := [1 + 3]byte{}
276
-
277
-	if _, err := io.ReadFull(r, header[:]); err != nil {
278
-		return nil, fmt.Errorf("cannot read handshake header: %w", err)
279
-	}
280
-
281
-	if header[0] != TypeHandshakeClient {
282
-		return nil, fmt.Errorf("incorrect handshake type: %#x", header[0])
283
-	}
284
-
285
-	// unfortunately there is not uint24 in golang, so we just reust header
286
-	header[0] = 0
287
-
288
-	length := int64(binary.BigEndian.Uint32(header[:]))
289
-	buf := &bytes.Buffer{}
290
-
291
-	_, err := io.CopyN(buf, r, length)
292
-
293
-	return buf, err
294
-}
295
-
296 128
 func parseHandshake(r io.Reader) (*ClientHello, error) {
297 129
 	//  A protocol version of "3,3" (meaning TLS 1.2) is given.
298 130
 	header := [2]byte{}

+ 7
- 144
mtglib/internal/tls/fake/client_side_test.go Прегледај датотеку

@@ -7,7 +7,6 @@ import (
7 7
 	"errors"
8 8
 	"io"
9 9
 	"os"
10
-	"path/filepath"
11 10
 	"testing"
12 11
 	"time"
13 12
 
@@ -397,142 +396,6 @@ func TestParseClientHelloSNI(t *testing.T) {
397 396
 	suite.Run(t, &ParseClientHelloSNITestSuite{})
398 397
 }
399 398
 
400
-// --- ReadClientHelloMulti tests ---
401
-
402
-type ReadClientHelloMultiTestSuite struct {
403
-	suite.Suite
404
-
405
-	secret mtglib.Secret
406
-}
407
-
408
-func (suite *ReadClientHelloMultiTestSuite) SetupSuite() {
409
-	parsed, err := mtglib.ParseSecret(
410
-		"ee367a189aee18fa31c190054efd4a8e9573746f726167652e676f6f676c65617069732e636f6d",
411
-	)
412
-	require.NoError(suite.T(), err)
413
-
414
-	suite.secret = parsed
415
-}
416
-
417
-func (suite *ReadClientHelloMultiTestSuite) loadSnapshot(name string) []byte {
418
-	data, err := os.ReadFile(filepath.Join("testdata", name))
419
-	require.NoError(suite.T(), err)
420
-
421
-	snapshot := &clientHelloSnapshot{}
422
-	require.NoError(suite.T(), json.Unmarshal(data, snapshot))
423
-
424
-	return snapshot.GetFull()
425
-}
426
-
427
-func (suite *ReadClientHelloMultiTestSuite) makeConn(data []byte) *parseClientHelloConnMock {
428
-	readBuf := &bytes.Buffer{}
429
-	readBuf.Write(data)
430
-
431
-	connMock := &parseClientHelloConnMock{
432
-		readBuf: readBuf,
433
-	}
434
-
435
-	connMock.
436
-		On("SetReadDeadline", mock.AnythingOfType("time.Time")).
437
-		Twice().
438
-		Return(nil)
439
-
440
-	return connMock
441
-}
442
-
443
-func (suite *ReadClientHelloMultiTestSuite) TestMatchesCorrectSecretAtIndex0() {
444
-	payload := suite.loadSnapshot("client-hello-ok-19dfe38384b9884b.json")
445
-	connMock := suite.makeConn(payload)
446
-	defer connMock.AssertExpectations(suite.T())
447
-
448
-	wrongSecret := mtglib.GenerateSecret("storage.googleapis.com")
449
-
450
-	result, err := fake.ReadClientHelloMulti(
451
-		connMock,
452
-		[][]byte{suite.secret.Key[:], wrongSecret.Key[:]},
453
-		suite.secret.Host,
454
-		TolerateTime,
455
-	)
456
-	suite.NoError(err)
457
-	suite.Equal(0, result.MatchedIndex)
458
-	suite.NotNil(result.Hello)
459
-}
460
-
461
-func (suite *ReadClientHelloMultiTestSuite) TestMatchesCorrectSecretAtIndex1() {
462
-	payload := suite.loadSnapshot("client-hello-ok-19dfe38384b9884b.json")
463
-	connMock := suite.makeConn(payload)
464
-	defer connMock.AssertExpectations(suite.T())
465
-
466
-	wrongSecret := mtglib.GenerateSecret("storage.googleapis.com")
467
-
468
-	result, err := fake.ReadClientHelloMulti(
469
-		connMock,
470
-		[][]byte{wrongSecret.Key[:], suite.secret.Key[:]},
471
-		suite.secret.Host,
472
-		TolerateTime,
473
-	)
474
-	suite.NoError(err)
475
-	suite.Equal(1, result.MatchedIndex)
476
-	suite.NotNil(result.Hello)
477
-}
478
-
479
-func (suite *ReadClientHelloMultiTestSuite) TestMatchesCorrectSecretAtIndex2() {
480
-	payload := suite.loadSnapshot("client-hello-ok-19dfe38384b9884b.json")
481
-	connMock := suite.makeConn(payload)
482
-	defer connMock.AssertExpectations(suite.T())
483
-
484
-	wrong1 := mtglib.GenerateSecret("storage.googleapis.com")
485
-	wrong2 := mtglib.GenerateSecret("storage.googleapis.com")
486
-
487
-	result, err := fake.ReadClientHelloMulti(
488
-		connMock,
489
-		[][]byte{wrong1.Key[:], wrong2.Key[:], suite.secret.Key[:]},
490
-		suite.secret.Host,
491
-		TolerateTime,
492
-	)
493
-	suite.NoError(err)
494
-	suite.Equal(2, result.MatchedIndex)
495
-	suite.NotNil(result.Hello)
496
-}
497
-
498
-func (suite *ReadClientHelloMultiTestSuite) TestNoMatchReturnsBadDigest() {
499
-	payload := suite.loadSnapshot("client-hello-ok-19dfe38384b9884b.json")
500
-	connMock := suite.makeConn(payload)
501
-	defer connMock.AssertExpectations(suite.T())
502
-
503
-	wrong1 := mtglib.GenerateSecret("storage.googleapis.com")
504
-	wrong2 := mtglib.GenerateSecret("storage.googleapis.com")
505
-
506
-	_, err := fake.ReadClientHelloMulti(
507
-		connMock,
508
-		[][]byte{wrong1.Key[:], wrong2.Key[:]},
509
-		suite.secret.Host,
510
-		TolerateTime,
511
-	)
512
-	suite.ErrorIs(err, fake.ErrBadDigest)
513
-}
514
-
515
-func (suite *ReadClientHelloMultiTestSuite) TestBadSnapshotReturnsBadDigest() {
516
-	payload := suite.loadSnapshot("client-hello-bad-fa2e46cdb33e2a1b.json")
517
-	connMock := suite.makeConn(payload)
518
-	defer connMock.AssertExpectations(suite.T())
519
-
520
-	_, err := fake.ReadClientHelloMulti(
521
-		connMock,
522
-		[][]byte{suite.secret.Key[:]},
523
-		suite.secret.Host,
524
-		TolerateTime,
525
-	)
526
-	suite.ErrorIs(err, fake.ErrBadDigest)
527
-}
528
-
529
-func TestReadClientHelloMulti(t *testing.T) {
530
-	t.Parallel()
531
-	suite.Run(t, &ReadClientHelloMultiTestSuite{})
532
-}
533
-
534
-// --- Fragmented TLS record tests ---
535
-
536 399
 // fragmentTLSRecord splits a single TLS record into n TLS records by
537 400
 // dividing the payload into roughly equal parts. Each part gets its own
538 401
 // TLS record header with the same record type and version.
@@ -680,7 +543,7 @@ func (s *ParseClientHelloFragmentedTestSuite) TestReassemblyErrors() {
680 543
 				buf.Write(payload[10:])
681 544
 				return buf.Bytes()
682 545
 			},
683
-			errMsg: "unexpected continuation record type",
546
+			errMsg: "unexpected record type",
684 547
 		},
685 548
 		{
686 549
 			name: "too many continuation records",
@@ -700,7 +563,7 @@ func (s *ParseClientHelloFragmentedTestSuite) TestReassemblyErrors() {
700 563
 				}
701 564
 				return buf.Bytes()
702 565
 			},
703
-			errMsg: "too many continuation records",
566
+			errMsg: "too many fragments",
704 567
 		},
705 568
 		{
706 569
 			name: "zero-length continuation record",
@@ -716,7 +579,7 @@ func (s *ParseClientHelloFragmentedTestSuite) TestReassemblyErrors() {
716 579
 				require.NoError(s.T(), binary.Write(buf, binary.BigEndian, uint16(0)))
717 580
 				return buf.Bytes()
718 581
 			},
719
-			errMsg: "zero-length continuation record",
582
+			errMsg: "cannot read record header",
720 583
 		},
721 584
 		{
722 585
 			name: "wrong continuation record version",
@@ -733,7 +596,7 @@ func (s *ParseClientHelloFragmentedTestSuite) TestReassemblyErrors() {
733 596
 				buf.Write(payload[10:])
734 597
 				return buf.Bytes()
735 598
 			},
736
-			errMsg: "unexpected continuation record version",
599
+			errMsg: "unexpected protocol version",
737 600
 		},
738 601
 		{
739 602
 			name: "handshake message too large",
@@ -747,7 +610,7 @@ func (s *ParseClientHelloFragmentedTestSuite) TestReassemblyErrors() {
747 610
 				buf.Write(handshakePayload)
748 611
 				return buf.Bytes()
749 612
 			},
750
-			errMsg: "handshake message too large",
613
+			errMsg: "cannot read record header",
751 614
 		},
752 615
 		{
753 616
 			name: "truncated continuation record header",
@@ -762,7 +625,7 @@ func (s *ParseClientHelloFragmentedTestSuite) TestReassemblyErrors() {
762 625
 				buf.WriteByte(3)
763 626
 				return buf.Bytes()
764 627
 			},
765
-			errMsg: "cannot read continuation record header",
628
+			errMsg: "cannot read record header",
766 629
 		},
767 630
 		{
768 631
 			name: "truncated continuation record payload",
@@ -778,7 +641,7 @@ func (s *ParseClientHelloFragmentedTestSuite) TestReassemblyErrors() {
778 641
 				require.NoError(s.T(), binary.Write(buf, binary.BigEndian, uint16(100)))
779 642
 				return buf.Bytes()
780 643
 			},
781
-			errMsg: "cannot read continuation record payload",
644
+			errMsg: "EOF",
782 645
 		},
783 646
 	}
784 647
 

+ 158
- 0
mtglib/internal/tls/fake/utils.go Прегледај датотеку

@@ -0,0 +1,158 @@
1
+package fake
2
+
3
+import (
4
+	"bytes"
5
+	"encoding/binary"
6
+	"errors"
7
+	"fmt"
8
+	"io"
9
+
10
+	"github.com/dolonet/mtg-multi/mtglib/internal/tls"
11
+)
12
+
13
+const (
14
+	maxFragmentsCount = 10
15
+)
16
+
17
+var ErrTooManyFragments = errors.New("too many fragments")
18
+
19
+// https://datatracker.ietf.org/doc/html/rfc5246#section-6.2.1
20
+// client hello can be fragmented in a series of packets:
21
+//
22
+//	Bytes on the wire:
23
+//
24
+// 16 03 01 00 F8 01 00 00 F4 03 03 [32 bytes random] [session_id] [ciphers] [SNI...]
25
+// ├─────────────┤├──────────────────────────────────────────────────────────────────┤
26
+//
27
+//	TLS record       Payload (248 bytes)
28
+//	header (5B)
29
+//
30
+//	16    = Handshake
31
+//	03 01 = TLS 1.0 (record layer version)
32
+//	00 F8 = 248 bytes follow
33
+//
34
+//	01       = ClientHello (handshake type)
35
+//	00 00 F4 = 244 bytes of handshake body
36
+//	03 03    = TLS 1.2 (actual protocol version)
37
+//	...rest of ClientHello...
38
+//
39
+// Fragmented record look like:
40
+//
41
+//	Record 1:
42
+//
43
+// 16 03 01 00 03 01 00 00
44
+// ├─────────────┤├──────┤
45
+//
46
+//	TLS header    3 bytes of payload
47
+//
48
+//	16    = Handshake
49
+//	03 01 = TLS 1.0
50
+//	00 03 = only 3 bytes follow
51
+//
52
+//	01       = ClientHello type
53
+//	00 00    = first 2 bytes of the uint24 length (INCOMPLETE!)
54
+//
55
+// Record 2:
56
+// 16 03 01 00 F5 F4 03 03 [32 bytes random] [session_id] [ciphers] [SNI...]
57
+// ├─────────────┤├────────────────────────────────────────────────────────────┤
58
+//
59
+//	TLS header    remaining 245 bytes of payload
60
+//
61
+//	16    = Handshake
62
+//	03 01 = TLS 1.0
63
+//	00 F5 = 245 bytes follow
64
+//
65
+//	F4    = last byte of uint24 length (now complete: 00 00 F4 = 244)
66
+//	03 03 = TLS 1.2
67
+//	...rest of ClientHello continues...
68
+//
69
+// So it means that there could be a series of handshake packets of different
70
+// lengths. The goal of this function is to concatenate these fragments.
71
+type fragmentedHandshakeReader struct {
72
+	r             io.Reader
73
+	buf           bytes.Buffer
74
+	readFragments int
75
+}
76
+
77
+func (f *fragmentedHandshakeReader) Read(p []byte) (int, error) {
78
+	if n, err := f.buf.Read(p); err == nil {
79
+		return n, nil
80
+	}
81
+
82
+	f.buf.Reset()
83
+
84
+	for f.buf.Len() == 0 {
85
+		if f.readFragments > maxFragmentsCount {
86
+			return 0, ErrTooManyFragments
87
+		}
88
+
89
+		if err := f.parseNextFragment(); err != nil {
90
+			return 0, err
91
+		}
92
+
93
+		f.readFragments++
94
+	}
95
+
96
+	return f.buf.Read(p)
97
+}
98
+
99
+func (f *fragmentedHandshakeReader) parseNextFragment() error {
100
+	// record_type(1) + version(2) + size(2)
101
+	//   16 - type is 0x16 (handshake record)
102
+	//   03 01 - protocol version is "3,1" (also known as TLS 1.0)
103
+	//   00 f8 - 0xF8 (248) bytes of handshake message follows
104
+	header := [1 + 2 + 2]byte{}
105
+
106
+	if _, err := io.ReadFull(f.r, header[:]); err != nil {
107
+		return fmt.Errorf("cannot read record header: %w", err)
108
+	}
109
+
110
+	if header[0] != tls.TypeHandshake {
111
+		return fmt.Errorf("unexpected record type %#x", header[0])
112
+	}
113
+
114
+	if header[1] != 3 || header[2] != 1 {
115
+		return fmt.Errorf("unexpected protocol version %#x %#x", header[1], header[2])
116
+	}
117
+
118
+	length := int64(binary.BigEndian.Uint16(header[3:]))
119
+	_, err := io.CopyN(&f.buf, f.r, length)
120
+
121
+	return err
122
+}
123
+
124
+func parseClientHello(r io.Reader) (*bytes.Buffer, *bytes.Buffer, error) {
125
+	r = &fragmentedHandshakeReader{r: r}
126
+	header := [1 + 3]byte{}
127
+
128
+	if _, err := io.ReadFull(r, header[:]); err != nil {
129
+		return nil, nil, fmt.Errorf("cannot read handshake header: %w", err)
130
+	}
131
+
132
+	if header[0] != TypeHandshakeClient {
133
+		return nil, nil, fmt.Errorf("incorrect handshake type: %#x", header[0])
134
+	}
135
+
136
+	// unfortunately there is not uint24 in golang, so we just reuse header
137
+	header[0] = 0
138
+	length := int64(binary.BigEndian.Uint32(header[:]))
139
+
140
+	clientHelloCopy := &bytes.Buffer{}
141
+	clientHelloCopy.Write([]byte{tls.TypeHandshake, 3, 1})
142
+	binary.Write( //nolint: errcheck
143
+		clientHelloCopy,
144
+		binary.BigEndian,
145
+		// 1 for handshake type
146
+		// 3 for handshake length
147
+		uint16(1+3+length),
148
+	)
149
+	clientHelloCopy.WriteByte(TypeHandshakeClient)
150
+	clientHelloCopy.Write(header[1:])
151
+
152
+	handshakeCopy := &bytes.Buffer{}
153
+	writer := io.MultiWriter(clientHelloCopy, handshakeCopy)
154
+
155
+	_, err := io.CopyN(writer, r, length)
156
+
157
+	return clientHelloCopy, handshakeCopy, err
158
+}

Loading…
Откажи
Сачувај