Преглед изворни кода

Merge pull request #140 from 9seconds/pools

Add pool support everywhere
tags/v1.0.4^2
Sergey Arkhipov пре 6 година
родитељ
комит
be1a47287a
No account linked to committer's email address

+ 6
- 1
faketls/client_protocol.go Прегледај датотеку

@@ -63,7 +63,12 @@ func (c *ClientProtocol) tlsHandshake(conn io.ReadWriter) error {
63 63
 		return fmt.Errorf("cannot read initial record: %w", err)
64 64
 	}
65 65
 
66
-	clientHello, err := tlstypes.ParseClientHello(helloRecord.Data.Bytes())
66
+	buf := acquireBytesBuffer()
67
+	defer releaseBytesBuffer(buf)
68
+
69
+	helloRecord.Data.WriteBytes(buf)
70
+
71
+	clientHello, err := tlstypes.ParseClientHello(buf.Bytes())
67 72
 	if err != nil {
68 73
 		return fmt.Errorf("cannot parse client hello: %w", err)
69 74
 	}

+ 11
- 8
faketls/cloak.go Прегледај датотеку

@@ -28,15 +28,9 @@ func cloak(one, another io.ReadWriteCloser) {
28 28
 
29 29
 	wg.Add(2)
30 30
 
31
-	go func() {
32
-		defer wg.Done()
33
-		io.Copy(one, another) // nolint: errcheck
34
-	}()
31
+	go cloakPipe(one, another, wg)
35 32
 
36
-	go func() {
37
-		defer wg.Done()
38
-		io.Copy(another, one) // nolint: errcheck
39
-	}()
33
+	go cloakPipe(another, one, wg)
40 34
 
41 35
 	go func() {
42 36
 		wg.Wait()
@@ -69,3 +63,12 @@ func cloak(one, another io.ReadWriteCloser) {
69 63
 
70 64
 	<-ctx.Done()
71 65
 }
66
+
67
+func cloakPipe(one io.Writer, another io.Reader, wg *sync.WaitGroup) {
68
+	defer wg.Done()
69
+
70
+	buf := acquireCloakBuffer()
71
+	defer releaseCloakBuffer(buf)
72
+
73
+	io.CopyBuffer(one, another, *buf) // nolint: errcheck
74
+}

+ 39
- 0
faketls/pools.go Прегледај датотеку

@@ -0,0 +1,39 @@
1
+package faketls
2
+
3
+import (
4
+	"bytes"
5
+	"sync"
6
+)
7
+
8
+const cloakBufferSize = 1024
9
+
10
+var (
11
+	poolBytesBuffer = sync.Pool{
12
+		New: func() interface{} {
13
+			return &bytes.Buffer{}
14
+		},
15
+	}
16
+	poolCloakBuffer = sync.Pool{
17
+		New: func() interface{} {
18
+			rv := make([]byte, cloakBufferSize)
19
+			return &rv
20
+		},
21
+	}
22
+)
23
+
24
+func acquireBytesBuffer() *bytes.Buffer {
25
+	return poolBytesBuffer.Get().(*bytes.Buffer)
26
+}
27
+
28
+func acquireCloakBuffer() *[]byte {
29
+	return poolCloakBuffer.Get().(*[]byte)
30
+}
31
+
32
+func releaseBytesBuffer(buf *bytes.Buffer) {
33
+	buf.Reset()
34
+	poolBytesBuffer.Put(buf)
35
+}
36
+
37
+func releaseCloakBuffer(buf *[]byte) {
38
+	poolCloakBuffer.Put(buf)
39
+}

+ 1
- 1
tlstypes/client_hello.go Прегледај датотеку

@@ -25,7 +25,7 @@ func (c ClientHello) Digest() []byte {
25 25
 	}
26 26
 
27 27
 	mac := hmac.New(sha256.New, config.C.Secret)
28
-	mac.Write(rec.Bytes()) // nolint: errcheck
28
+	rec.WriteBytes(mac)
29 29
 	computedDigest := mac.Sum(nil)
30 30
 
31 31
 	for i := range computedDigest {

+ 10
- 3
tlstypes/consts.go Прегледај датотеку

@@ -1,5 +1,7 @@
1 1
 package tlstypes
2 2
 
3
+import "io"
4
+
3 5
 type RecordType uint8
4 6
 
5 7
 const (
@@ -69,11 +71,16 @@ var (
69 71
 )
70 72
 
71 73
 type Byter interface {
72
-	Bytes() []byte
74
+	WriteBytes(io.Writer)
75
+	Len() int
73 76
 }
74 77
 
75 78
 type RawBytes []byte
76 79
 
77
-func (r RawBytes) Bytes() []byte {
78
-	return []byte(r)
80
+func (r RawBytes) WriteBytes(writer io.Writer) {
81
+	writer.Write(r) // nolint: errcheck
82
+}
83
+
84
+func (r RawBytes) Len() int {
85
+	return len(r)
79 86
 }

+ 16
- 9
tlstypes/handshake.go Прегледај датотеку

@@ -1,7 +1,7 @@
1 1
 package tlstypes
2 2
 
3 3
 import (
4
-	"bytes"
4
+	"io"
5 5
 
6 6
 	"github.com/9seconds/mtg/utils"
7 7
 )
@@ -14,24 +14,31 @@ type Handshake struct {
14 14
 	Tail      Byter
15 15
 }
16 16
 
17
-func (h *Handshake) Bytes() []byte {
18
-	buf := bytes.Buffer{}
19
-	packetBuf := bytes.Buffer{}
17
+func (h *Handshake) WriteBytes(writer io.Writer) {
18
+	packetBuf := acquireBytesBuffer()
19
+	defer releaseBytesBuffer(packetBuf)
20 20
 
21
-	buf.WriteByte(byte(h.Type))
21
+	writer.Write([]byte{byte(h.Type)}) // nolint: errcheck
22 22
 
23 23
 	packetBuf.Write(h.Version.Bytes())
24 24
 	packetBuf.Write(h.Random[:])
25 25
 	packetBuf.WriteByte(byte(len(h.SessionID)))
26 26
 	packetBuf.Write(h.SessionID)
27
-	packetBuf.Write(h.Tail.Bytes())
27
+	h.Tail.WriteBytes(packetBuf)
28 28
 
29 29
 	sizeUint24 := utils.ToUint24(uint32(packetBuf.Len()))
30 30
 	sizeUint24Bytes := sizeUint24[:]
31 31
 	sizeUint24Bytes[0], sizeUint24Bytes[2] = sizeUint24Bytes[2], sizeUint24Bytes[0]
32 32
 
33
-	buf.Write(sizeUint24Bytes)
34
-	packetBuf.WriteTo(&buf) // nolint: errcheck
33
+	writer.Write(sizeUint24Bytes) // nolint: errcheck
34
+	packetBuf.WriteTo(writer)     // nolint: errcheck
35
+}
36
+
37
+func (h *Handshake) Len() int {
38
+	buf := acquireBytesBuffer()
39
+	defer releaseBytesBuffer(buf)
40
+
41
+	h.WriteBytes(buf)
35 42
 
36
-	return buf.Bytes()
43
+	return buf.Len()
37 44
 }

+ 23
- 0
tlstypes/pools.go Прегледај датотеку

@@ -0,0 +1,23 @@
1
+package tlstypes
2
+
3
+import (
4
+	"bytes"
5
+	"sync"
6
+)
7
+
8
+var (
9
+	poolBytesBuffer = sync.Pool{
10
+		New: func() interface{} {
11
+			return &bytes.Buffer{}
12
+		},
13
+	}
14
+)
15
+
16
+func acquireBytesBuffer() *bytes.Buffer {
17
+	return poolBytesBuffer.Get().(*bytes.Buffer)
18
+}
19
+
20
+func releaseBytesBuffer(buf *bytes.Buffer) {
21
+	buf.Reset()
22
+	poolBytesBuffer.Put(buf)
23
+}

+ 8
- 9
tlstypes/record.go Прегледај датотеку

@@ -15,16 +15,15 @@ type Record struct {
15 15
 	Data    Byter
16 16
 }
17 17
 
18
-func (r Record) Bytes() []byte {
19
-	buf := bytes.Buffer{}
20
-	data := r.Data.Bytes()
21
-
22
-	buf.WriteByte(byte(r.Type))
23
-	buf.Write(r.Version.Bytes())
24
-	binary.Write(&buf, binary.BigEndian, uint16(len(data))) // nolint: errcheck
25
-	buf.Write(data)
18
+func (r Record) WriteBytes(writer io.Writer) {
19
+	writer.Write([]byte{byte(r.Type)})                           // nolint: errcheck
20
+	writer.Write(r.Version.Bytes())                              // nolint: errcheck
21
+	binary.Write(writer, binary.BigEndian, uint16(r.Data.Len())) // nolint: errcheck
22
+	r.Data.WriteBytes(writer)
23
+}
26 24
 
27
-	return buf.Bytes()
25
+func (r Record) Len() int {
26
+	return 1 + 2 + 2 + r.Data.Len()
28 27
 }
29 28
 
30 29
 func ReadRecord(reader io.Reader) (Record, error) {

+ 6
- 3
tlstypes/server_hello.go Прегледај датотеку

@@ -20,20 +20,22 @@ type ServerHello struct {
20 20
 }
21 21
 
22 22
 func (s ServerHello) WelcomePacket() []byte {
23
+	buf := &bytes.Buffer{}
24
+
23 25
 	s.Random = [32]byte{}
24 26
 	rec := Record{
25 27
 		Type:    RecordTypeHandshake,
26 28
 		Version: Version12,
27 29
 		Data:    &s,
28 30
 	}
29
-	buf := bytes.NewBuffer(rec.Bytes())
31
+	rec.WriteBytes(buf)
30 32
 
31 33
 	recChangeCipher := Record{
32 34
 		Type:    RecordTypeChangeCipherSpec,
33 35
 		Version: Version12,
34 36
 		Data:    RawBytes([]byte{0x01}),
35 37
 	}
36
-	buf.Write(recChangeCipher.Bytes())
38
+	recChangeCipher.WriteBytes(buf)
37 39
 
38 40
 	hostCert := make([]byte, 1024+mrand.Intn(3092))
39 41
 	rand.Read(hostCert) // nolint: errcheck
@@ -43,7 +45,8 @@ func (s ServerHello) WelcomePacket() []byte {
43 45
 		Version: Version12,
44 46
 		Data:    RawBytes(hostCert),
45 47
 	}
46
-	buf.Write(recData.Bytes())
48
+	recData.WriteBytes(buf)
49
+
47 50
 	packet := buf.Bytes()
48 51
 
49 52
 	mac := hmac.New(sha256.New, config.C.Secret)

+ 5
- 4
wrappers/packet/mtproto_frame.go Прегледај датотеку

@@ -42,7 +42,9 @@ type wrapperMtprotoFrame struct {
42 42
 }
43 43
 
44 44
 func (w *wrapperMtprotoFrame) Read() (conntypes.Packet, error) { // nolint: funlen
45
-	buf := &bytes.Buffer{}
45
+	buf := acquireMtprotoFrameBytesBuffer()
46
+	defer releaseMtprotoFrameBytesBuffer(buf)
47
+
46 48
 	sum := crc32.NewIEEE()
47 49
 	writer := io.MultiWriter(buf, sum)
48 50
 
@@ -71,7 +73,6 @@ func (w *wrapperMtprotoFrame) Read() (conntypes.Packet, error) { // nolint: funl
71 73
 	}
72 74
 
73 75
 	buf.Reset()
74
-	buf.Grow(int(messageLength) - 4 - 4)
75 76
 
76 77
 	if _, err := io.CopyN(writer, w.parent, int64(messageLength)-4-4); err != nil {
77 78
 		return nil, fmt.Errorf("cannot read the message frame: %w", err)
@@ -113,8 +114,8 @@ func (w *wrapperMtprotoFrame) Write(p conntypes.Packet) error {
113 114
 	messageLength := 4 + 4 + len(p) + 4
114 115
 	paddingLength := (aes.BlockSize - messageLength%aes.BlockSize) % aes.BlockSize
115 116
 
116
-	buf := &bytes.Buffer{}
117
-	buf.Grow(messageLength + paddingLength)
117
+	buf := acquireMtprotoFrameBytesBuffer()
118
+	defer releaseMtprotoFrameBytesBuffer(buf)
118 119
 
119 120
 	binary.Write(buf, binary.LittleEndian, uint32(messageLength)) // nolint: errcheck
120 121
 	binary.Write(buf, binary.LittleEndian, w.writeSeqNo)          // nolint: errcheck

+ 23
- 0
wrappers/packet/pools.go Прегледај датотеку

@@ -0,0 +1,23 @@
1
+package packet
2
+
3
+import (
4
+	"bytes"
5
+	"sync"
6
+)
7
+
8
+var (
9
+	poolMtprotoFrameBytesBuffer = sync.Pool{
10
+		New: func() interface{} {
11
+			return &bytes.Buffer{}
12
+		},
13
+	}
14
+)
15
+
16
+func acquireMtprotoFrameBytesBuffer() *bytes.Buffer {
17
+	return poolMtprotoFrameBytesBuffer.Get().(*bytes.Buffer)
18
+}
19
+
20
+func releaseMtprotoFrameBytesBuffer(buf *bytes.Buffer) {
21
+	buf.Reset()
22
+	poolMtprotoFrameBytesBuffer.Put(buf)
23
+}

+ 3
- 1
wrappers/packetack/client_abridged.go Прегледај датотеку

@@ -88,7 +88,9 @@ func (w *wrapperClientAbridged) Write(packet conntypes.Packet, acks *conntypes.C
88 88
 		return nil
89 89
 	case packetLength < clientAbridgedLargePacketLength:
90 90
 		length24 := utils.ToUint24(uint32(packetLength))
91
-		buf := bytes.Buffer{}
91
+
92
+		buf := acquireClientBytesBuffer()
93
+		defer releaseClientBytesBuffer(buf)
92 94
 
93 95
 		buf.WriteByte(byte(clientAbridgedSmallPacketLength))
94 96
 		buf.Write(length24[:])

+ 4
- 3
wrappers/packetack/client_intermediate_secure.go Прегледај датотеку

@@ -1,7 +1,6 @@
1 1
 package packetack
2 2
 
3 3
 import (
4
-	"bytes"
5 4
 	"encoding/binary"
6 5
 	"fmt"
7 6
 	"math/rand"
@@ -35,11 +34,13 @@ func (w *wrapperClientIntermediateSecure) Write(packet conntypes.Packet, acks *c
35 34
 		return nil
36 35
 	}
37 36
 
38
-	buf := bytes.Buffer{}
37
+	buf := acquireClientBytesBuffer()
38
+	defer releaseClientBytesBuffer(buf)
39
+
39 40
 	paddingLength := rand.Intn(4)
40 41
 	buf.Grow(4 + len(packet) + paddingLength)
41 42
 
42
-	binary.Write(&buf, binary.LittleEndian, uint32(len(packet)+paddingLength)) // nolint: errcheck
43
+	binary.Write(buf, binary.LittleEndian, uint32(len(packet)+paddingLength)) // nolint: errcheck
43 44
 	buf.Write(packet)
44 45
 	buf.Write(make([]byte, paddingLength))
45 46
 

+ 23
- 0
wrappers/packetack/pools.go Прегледај датотеку

@@ -0,0 +1,23 @@
1
+package packetack
2
+
3
+import (
4
+	"bytes"
5
+	"sync"
6
+)
7
+
8
+var (
9
+	poolClientBytesBuffer = sync.Pool{
10
+		New: func() interface{} {
11
+			return &bytes.Buffer{}
12
+		},
13
+	}
14
+)
15
+
16
+func acquireClientBytesBuffer() *bytes.Buffer {
17
+	return poolClientBytesBuffer.Get().(*bytes.Buffer)
18
+}
19
+
20
+func releaseClientBytesBuffer(buf *bytes.Buffer) {
21
+	buf.Reset()
22
+	poolClientBytesBuffer.Put(buf)
23
+}

+ 2
- 1
wrappers/packetack/proxy.go Прегледај датотеку

@@ -23,8 +23,8 @@ type wrapperProxy struct {
23 23
 
24 24
 func (w *wrapperProxy) Write(packet conntypes.Packet, acks *conntypes.ConnectionAcks) error {
25 25
 	buf := bytes.Buffer{}
26
-
27 26
 	flags := w.flags
27
+
28 28
 	if acks.Quick {
29 29
 		flags |= rpc.ProxyRequestFlagsQuickAck
30 30
 	}
@@ -43,6 +43,7 @@ func (w *wrapperProxy) Write(packet conntypes.Packet, acks *conntypes.Connection
43 43
 	buf.WriteByte(byte(len(config.C.AdTag)))
44 44
 	buf.Write(config.C.AdTag)
45 45
 	buf.Write(make([]byte, (4-buf.Len()%4)%4))
46
+	buf.Grow(len(packet))
46 47
 	buf.Write(packet)
47 48
 
48 49
 	return w.proxy.Write(buf.Bytes())

+ 13
- 3
wrappers/stream/faketls.go Прегледај датотеку

@@ -1,6 +1,7 @@
1 1
 package stream
2 2
 
3 3
 import (
4
+	"bytes"
4 5
 	"errors"
5 6
 	"fmt"
6 7
 	"net"
@@ -39,13 +40,19 @@ func (w *wrapperFakeTLS) WriteTimeout(p []byte, timeout time.Duration) (int, err
39 40
 func (w *wrapperFakeTLS) write(p []byte, writeFunc func([]byte) (int, error)) (int, error) {
40 41
 	sum := 0
41 42
 
43
+	buf := acquireBytesBuffer()
44
+	defer releaseBytesBuffer(buf)
45
+
42 46
 	for _, v := range tlstypes.MakeRecords(p) {
43
-		_, err := writeFunc(v.Bytes())
47
+		buf.Reset()
48
+		v.WriteBytes(buf)
49
+
50
+		_, err := writeFunc(buf.Bytes())
44 51
 		if err != nil {
45 52
 			return sum, err
46 53
 		}
47 54
 
48
-		sum += len(v.Data.Bytes())
55
+		sum += v.Data.Len()
49 56
 	}
50 57
 
51 58
 	return sum, nil
@@ -86,7 +93,10 @@ func NewFakeTLS(socket conntypes.StreamReadWriteCloser) conntypes.StreamReadWrit
86 93
 			switch rec.Type {
87 94
 			case tlstypes.RecordTypeChangeCipherSpec:
88 95
 			case tlstypes.RecordTypeApplicationData:
89
-				return rec.Data.Bytes(), nil
96
+				buf := &bytes.Buffer{}
97
+				rec.Data.WriteBytes(buf)
98
+
99
+				return buf.Bytes(), nil
90 100
 			default:
91 101
 				return nil, fmt.Errorf("unsupported record type %v", rec.Type)
92 102
 			}

+ 3
- 2
wrappers/stream/mtproto_cipher.go Прегледај датотеку

@@ -1,7 +1,6 @@
1 1
 package stream
2 2
 
3 3
 import (
4
-	"bytes"
5 4
 	"crypto/aes"
6 5
 	"crypto/cipher"
7 6
 	"crypto/md5"  // nolint: gosec
@@ -54,7 +53,9 @@ func mtprotoDeriveKeys(purpose mtprotoCipherPurpose,
54 53
 	resp *rpc.NonceResponse,
55 54
 	client, remote *net.TCPAddr,
56 55
 	secret []byte) ([]byte, []byte) {
57
-	message := bytes.Buffer{}
56
+	message := acquireBytesBuffer()
57
+	defer releaseBytesBuffer(message)
58
+
58 59
 	message.Write(resp.Nonce)   // nolint: gosec
59 60
 	message.Write(req.Nonce)    // nolint: gosec
60 61
 	message.Write(req.CryptoTS) // nolint: gosec

+ 4
- 23
wrappers/stream/obfuscated2.go Прегледај датотеку

@@ -1,11 +1,9 @@
1 1
 package stream
2 2
 
3 3
 import (
4
-	"bytes"
5 4
 	"crypto/cipher"
6 5
 	"fmt"
7 6
 	"net"
8
-	"sync"
9 7
 	"time"
10 8
 
11 9
 	"go.uber.org/zap"
@@ -13,23 +11,6 @@ import (
13 11
 	"github.com/9seconds/mtg/conntypes"
14 12
 )
15 13
 
16
-var (
17
-	poolWrapperObfuscated2WritePool = sync.Pool{
18
-		New: func() interface{} {
19
-			return &bytes.Buffer{}
20
-		},
21
-	}
22
-)
23
-
24
-func poolWrapperObfuscated2WritePoolAcquire() *bytes.Buffer {
25
-	return poolWrapperObfuscated2WritePool.Get().(*bytes.Buffer)
26
-}
27
-
28
-func poolWrapperObfuscated2WritePoolRelease(buf *bytes.Buffer) {
29
-	buf.Reset()
30
-	poolWrapperObfuscated2WritePool.Put(buf)
31
-}
32
-
33 14
 type wrapperObfuscated2 struct {
34 15
 	encryptor cipher.Stream
35 16
 	decryptor cipher.Stream
@@ -59,8 +40,8 @@ func (w *wrapperObfuscated2) Read(p []byte) (int, error) {
59 40
 }
60 41
 
61 42
 func (w *wrapperObfuscated2) WriteTimeout(p []byte, timeout time.Duration) (int, error) {
62
-	buffer := poolWrapperObfuscated2WritePoolAcquire()
63
-	defer poolWrapperObfuscated2WritePoolRelease(buffer)
43
+	buffer := acquireBytesBuffer()
44
+	defer releaseBytesBuffer(buffer)
64 45
 
65 46
 	buffer.Write(p)
66 47
 
@@ -72,8 +53,8 @@ func (w *wrapperObfuscated2) WriteTimeout(p []byte, timeout time.Duration) (int,
72 53
 }
73 54
 
74 55
 func (w *wrapperObfuscated2) Write(p []byte) (int, error) {
75
-	buffer := poolWrapperObfuscated2WritePoolAcquire()
76
-	defer poolWrapperObfuscated2WritePoolRelease(buffer)
56
+	buffer := acquireBytesBuffer()
57
+	defer releaseBytesBuffer(buffer)
77 58
 
78 59
 	buffer.Write(p)
79 60
 

+ 23
- 0
wrappers/stream/pools.go Прегледај датотеку

@@ -0,0 +1,23 @@
1
+package stream
2
+
3
+import (
4
+	"bytes"
5
+	"sync"
6
+)
7
+
8
+var (
9
+	poolBytesBuffer = sync.Pool{
10
+		New: func() interface{} {
11
+			return &bytes.Buffer{}
12
+		},
13
+	}
14
+)
15
+
16
+func acquireBytesBuffer() *bytes.Buffer {
17
+	return poolBytesBuffer.Get().(*bytes.Buffer)
18
+}
19
+
20
+func releaseBytesBuffer(buf *bytes.Buffer) {
21
+	buf.Reset()
22
+	poolBytesBuffer.Put(buf)
23
+}

Loading…
Откажи
Сачувај