Sfoglia il codice sorgente

Correctly manage partial writes

tags/v2.0.0-rc1
9seconds 5 anni fa
parent
commit
3992054560

+ 25
- 6
mtglib/internal/faketls/record/pools.go Vedi File

@@ -1,12 +1,22 @@
1 1
 package record
2 2
 
3
-import "sync"
3
+import (
4
+	"bytes"
5
+	"sync"
6
+)
4 7
 
5
-var recordPool = sync.Pool{
6
-	New: func() interface{} {
7
-		return &Record{}
8
-	},
9
-}
8
+var (
9
+	recordPool = sync.Pool{
10
+		New: func() interface{} {
11
+			return &Record{}
12
+		},
13
+	}
14
+	bytesBufferPool = sync.Pool{
15
+		New: func() interface{} {
16
+			return &bytes.Buffer{}
17
+		},
18
+	}
19
+)
10 20
 
11 21
 func AcquireRecord() *Record {
12 22
 	return recordPool.Get().(*Record)
@@ -16,3 +26,12 @@ func ReleaseRecord(r *Record) {
16 26
 	r.Reset()
17 27
 	recordPool.Put(r)
18 28
 }
29
+
30
+func acquireBytesBuffer() *bytes.Buffer {
31
+	return bytesBufferPool.Get().(*bytes.Buffer)
32
+}
33
+
34
+func releaseBytesBuffer(buf *bytes.Buffer) {
35
+	buf.Reset()
36
+	bytesBufferPool.Put(buf)
37
+}

+ 11
- 15
mtglib/internal/faketls/record/record.go Vedi File

@@ -61,26 +61,22 @@ func (r *Record) Read(reader io.Reader) error {
61 61
 }
62 62
 
63 63
 func (r *Record) Dump(writer io.Writer) error {
64
-	buf := [2]byte{byte(r.Type), 0}
64
+	buf := acquireBytesBuffer()
65
+	defer releaseBytesBuffer(buf)
65 66
 
66
-	if _, err := writer.Write(buf[:1]); err != nil {
67
-		return fmt.Errorf("cannot dump type: %w", err)
68
-	}
69
-
70
-	binary.BigEndian.PutUint16(buf[:], uint16(r.Version))
67
+	bufSlice := [2]byte{byte(r.Type), 0}
68
+	buf.Write(bufSlice[:1])
71 69
 
72
-	if _, err := writer.Write(buf[:]); err != nil {
73
-		return fmt.Errorf("cannot dump version: %w", err)
74
-	}
70
+	binary.BigEndian.PutUint16(bufSlice[:], uint16(r.Version))
71
+	buf.Write(bufSlice[:])
75 72
 
76
-	binary.BigEndian.PutUint16(buf[:], uint16(r.Payload.Len()))
73
+	binary.BigEndian.PutUint16(bufSlice[:], uint16(r.Payload.Len()))
74
+	buf.Write(bufSlice[:])
77 75
 
78
-	if _, err := writer.Write(buf[:]); err != nil {
79
-		return fmt.Errorf("cannot dump payload length: %w", err)
80
-	}
76
+	buf.Write(r.Payload.Bytes())
81 77
 
82
-	if _, err := writer.Write(r.Payload.Bytes()); err != nil {
83
-		return fmt.Errorf("cannot dump payload: %w", err)
78
+	if _, err := buf.WriteTo(writer); err != nil {
79
+		return fmt.Errorf("cannot dump record: %w", err)
84 80
 	}
85 81
 
86 82
 	return nil

+ 1
- 1
mtglib/internal/obfuscated2/client_handshake_test.go Vedi File

@@ -58,7 +58,7 @@ func (suite *ClientHandshakeTestSuite) TestOk() {
58 58
 					copy(writeData, arr)
59 59
 				})
60 60
 
61
-			conn := &obfuscated2.Conn{
61
+			conn := obfuscated2.Conn{
62 62
 				Conn:      connMock,
63 63
 				Encryptor: encryptor,
64 64
 				Decryptor: decryptor,

+ 12
- 7
mtglib/internal/obfuscated2/conn.go Vedi File

@@ -10,11 +10,9 @@ type Conn struct {
10 10
 
11 11
 	Encryptor cipher.Stream
12 12
 	Decryptor cipher.Stream
13
-
14
-	writeBuf []byte
15 13
 }
16 14
 
17
-func (c *Conn) Read(p []byte) (int, error) {
15
+func (c Conn) Read(p []byte) (int, error) {
18 16
 	n, err := c.Conn.Read(p)
19 17
 	if err != nil {
20 18
 		return n, err // nolint: wrapcheck
@@ -25,9 +23,16 @@ func (c *Conn) Read(p []byte) (int, error) {
25 23
 	return n, nil
26 24
 }
27 25
 
28
-func (c *Conn) Write(p []byte) (int, error) {
29
-	c.writeBuf = append(c.writeBuf[:0], p...)
30
-	c.Encryptor.XORKeyStream(c.writeBuf, p)
26
+func (c Conn) Write(p []byte) (int, error) {
27
+	buf := acquireBytesBuffer()
28
+	defer releaseBytesBuffer(buf)
29
+
30
+	buf.Write(p)
31
+
32
+	payload := buf.Bytes()
33
+	c.Encryptor.XORKeyStream(payload, payload)
34
+
35
+	n, err := buf.WriteTo(c.Conn)
31 36
 
32
-	return c.Conn.Write(c.writeBuf)
37
+	return int(n), err // nolint: wrapcheck
33 38
 }

+ 22
- 5
mtglib/internal/obfuscated2/pools.go Vedi File

@@ -1,16 +1,24 @@
1 1
 package obfuscated2
2 2
 
3 3
 import (
4
+	"bytes"
4 5
 	"crypto/sha256"
5 6
 	"hash"
6 7
 	"sync"
7 8
 )
8 9
 
9
-var sha256HasherPool = sync.Pool{
10
-	New: func() interface{} {
11
-		return sha256.New()
12
-	},
13
-}
10
+var (
11
+	sha256HasherPool = sync.Pool{
12
+		New: func() interface{} {
13
+			return sha256.New()
14
+		},
15
+	}
16
+	bytesBufferPool = sync.Pool{
17
+		New: func() interface{} {
18
+			return &bytes.Buffer{}
19
+		},
20
+	}
21
+)
14 22
 
15 23
 func acquireSha256Hasher() hash.Hash {
16 24
 	return sha256HasherPool.Get().(hash.Hash)
@@ -20,3 +28,12 @@ func releaseSha256Hasher(h hash.Hash) {
20 28
 	h.Reset()
21 29
 	sha256HasherPool.Put(h)
22 30
 }
31
+
32
+func acquireBytesBuffer() *bytes.Buffer {
33
+	return bytesBufferPool.Get().(*bytes.Buffer)
34
+}
35
+
36
+func releaseBytesBuffer(buf *bytes.Buffer) {
37
+	buf.Reset()
38
+	bytesBufferPool.Put(buf)
39
+}

+ 2
- 2
mtglib/internal/obfuscated2/server_handshake_test.go Vedi File

@@ -17,7 +17,7 @@ type ServerHandshakeTestSuite struct {
17 17
 	suite.Suite
18 18
 
19 19
 	connMock  *testlib.NetConnMock
20
-	proxyConn *obfuscated2.Conn
20
+	proxyConn obfuscated2.Conn
21 21
 	encryptor cipher.Stream
22 22
 	decryptor cipher.Stream
23 23
 }
@@ -29,7 +29,7 @@ func (suite *ServerHandshakeTestSuite) SetupTest() {
29 29
 	encryptor, decryptor, err := obfuscated2.ServerHandshake(buf)
30 30
 	suite.NoError(err)
31 31
 
32
-	suite.proxyConn = &obfuscated2.Conn{
32
+	suite.proxyConn = obfuscated2.Conn{
33 33
 		Conn:      suite.connMock,
34 34
 		Encryptor: encryptor,
35 35
 		Decryptor: decryptor,

+ 2
- 2
mtglib/proxy.go Vedi File

@@ -192,7 +192,7 @@ func (p *Proxy) doObfuscated2Handshake(ctx *streamContext) error {
192 192
 
193 193
 	ctx.dc = dc
194 194
 	ctx.logger = ctx.logger.BindInt("dc", dc)
195
-	ctx.clientConn = &obfuscated2.Conn{
195
+	ctx.clientConn = obfuscated2.Conn{
196 196
 		Conn:      ctx.clientConn,
197 197
 		Encryptor: encryptor,
198 198
 		Decryptor: decryptor,
@@ -214,7 +214,7 @@ func (p *Proxy) doTelegramCall(ctx *streamContext) error {
214 214
 		return fmt.Errorf("cannot perform obfuscated2 handshake: %w", err)
215 215
 	}
216 216
 
217
-	ctx.telegramConn = &obfuscated2.Conn{
217
+	ctx.telegramConn = obfuscated2.Conn{
218 218
 		Conn: connTelegramTraffic{
219 219
 			Conn:   conn,
220 220
 			connID: ctx.connID,

Loading…
Annulla
Salva