Преглед на файлове

Add buffered reader

tags/0.9
9seconds преди 7 години
родител
ревизия
586ffab01b
променени са 4 файла, в които са добавени 100 реда и са изтрити 114 реда
  1. 3
    7
      mtproto/rpc/rpc_proxy_request.go
  2. 42
    61
      mtproto/wrappers/frame.go
  3. 17
    46
      wrappers/blockcipherrwc.go
  4. 38
    0
      wrappers/buffered_reader.go

+ 3
- 7
mtproto/rpc/rpc_proxy_request.go Целия файл

@@ -29,10 +29,9 @@ type RPCProxyRequest struct {
29 29
 	LocalIPPort  [rpcProxyRequestIPPortLength]byte
30 30
 	ADTag        []byte
31 31
 	Extras       Extras
32
-	Message      *bytes.Buffer
33 32
 }
34 33
 
35
-func (r *RPCProxyRequest) Bytes() []byte {
34
+func (r *RPCProxyRequest) Bytes(message []byte) []byte {
36 35
 	buf := &bytes.Buffer{}
37 36
 
38 37
 	flags := r.Flags
@@ -40,8 +39,7 @@ func (r *RPCProxyRequest) Bytes() []byte {
40 39
 		flags |= RPCProxyRequestFlagsQuickAck
41 40
 	}
42 41
 
43
-	messageBytes := r.Message.Bytes()
44
-	if bytes.HasPrefix(messageBytes, rpcProxyRequestFlagsEncryptedPrefix[:]) {
42
+	if bytes.HasPrefix(message, rpcProxyRequestFlagsEncryptedPrefix[:]) {
45 43
 		flags |= RPCProxyRequestFlagsEncrypted
46 44
 	}
47 45
 
@@ -58,9 +56,7 @@ func (r *RPCProxyRequest) Bytes() []byte {
58 56
 	for i := 0; i < (buf.Len() % 4); i++ {
59 57
 		buf.WriteByte(0x00)
60 58
 	}
61
-	if r.Message != nil {
62
-		buf.Write(messageBytes)
63
-	}
59
+	buf.Write(message)
64 60
 
65 61
 	return buf.Bytes()
66 62
 }

+ 42
- 61
mtproto/wrappers/frame.go Целия файл

@@ -22,11 +22,11 @@ const (
22 22
 var frameRWCPadding = [4]byte{0x04, 0x00, 0x00, 0x00}
23 23
 
24 24
 type FrameRWC struct {
25
-	conn wrappers.ReadWriteCloserWithAddr
25
+	wrappers.BufferedReader
26 26
 
27
+	conn       wrappers.ReadWriteCloserWithAddr
27 28
 	readSeqNo  int32
28 29
 	writeSeqNo int32
29
-	readBuf    *bytes.Buffer
30 30
 }
31 31
 
32 32
 func (f *FrameRWC) Write(buf []byte) (int, error) {
@@ -54,51 +54,49 @@ func (f *FrameRWC) Write(buf []byte) (int, error) {
54 54
 }
55 55
 
56 56
 func (f *FrameRWC) Read(p []byte) (int, error) {
57
-	if f.readBuf.Len() > 0 {
58
-		return f.flush(p)
59
-	}
60
-
61
-	buf := &bytes.Buffer{}
62
-	for {
63
-		buf.Reset()
64
-		if _, err := io.CopyN(buf, f.conn, 4); err != nil {
65
-			return 0, errors.Annotate(err, "Cannot read frame padding")
57
+	return f.BufferedRead(p, func() error {
58
+		buf := &bytes.Buffer{}
59
+		for {
60
+			buf.Reset()
61
+			if _, err := io.CopyN(buf, f.conn, 4); err != nil {
62
+				return errors.Annotate(err, "Cannot read frame padding")
63
+			}
64
+			if !bytes.Equal(buf.Bytes(), frameRWCPadding[:]) {
65
+				break
66
+			}
66 67
 		}
67
-		if !bytes.Equal(buf.Bytes(), frameRWCPadding[:]) {
68
-			break
69
-		}
70
-	}
71 68
 
72
-	messageLength := binary.LittleEndian.Uint32(buf.Bytes())
73
-	if messageLength%4 != 0 || messageLength < frameRWCMinMessageLength || messageLength > frameRWCMaxMessageLength {
74
-		return 0, errors.Errorf("Incorrect frame message length %d", messageLength)
75
-	}
76
-	sum := crc32.NewIEEE()
77
-	sum.Write(buf.Bytes())
69
+		messageLength := binary.LittleEndian.Uint32(buf.Bytes())
70
+		if messageLength%4 != 0 || messageLength < frameRWCMinMessageLength || messageLength > frameRWCMaxMessageLength {
71
+			return errors.Errorf("Incorrect frame message length %d", messageLength)
72
+		}
73
+		sum := crc32.NewIEEE()
74
+		sum.Write(buf.Bytes())
78 75
 
79
-	buf.Reset()
80
-	buf.Grow(int(messageLength) - 4) // -4 because we already read the first number
81
-	if _, err := io.CopyN(buf, f.conn, int64(messageLength)-4); err != nil {
82
-		return 0, errors.Annotate(err, "Cannot read the message frame")
83
-	}
84
-	sum.Write(buf.Bytes())
76
+		buf.Reset()
77
+		buf.Grow(int(messageLength) - 4) // -4 because we already read the first number
78
+		if _, err := io.CopyN(buf, f.conn, int64(messageLength)-4); err != nil {
79
+			return errors.Annotate(err, "Cannot read the message frame")
80
+		}
81
+		sum.Write(buf.Bytes())
85 82
 
86
-	var seqNo int32
87
-	binary.Read(buf, binary.LittleEndian, seqNo)
88
-	if seqNo != f.readSeqNo {
89
-		return 0, errors.Errorf("Unexpected sequence number %d (wait for %d)", seqNo, f.readSeqNo)
90
-	}
91
-	f.readSeqNo++
83
+		var seqNo int32
84
+		binary.Read(buf, binary.LittleEndian, seqNo)
85
+		if seqNo != f.readSeqNo {
86
+			return errors.Errorf("Unexpected sequence number %d (wait for %d)", seqNo, f.readSeqNo)
87
+		}
88
+		f.readSeqNo++
92 89
 
93
-	data := buf.Bytes()[:int(messageLength)-4-4-4]
94
-	checksum := binary.LittleEndian.Uint32(buf.Bytes()[int(messageLength)-4-4-4:])
95
-	if checksum != sum.Sum32() {
96
-		return 0, errors.Errorf("CRC32 checksum mismatch. Wait for %d, got %d", sum.Sum32(), checksum)
90
+		data := buf.Bytes()[:int(messageLength)-4-4-4]
91
+		checksum := binary.LittleEndian.Uint32(buf.Bytes()[int(messageLength)-4-4-4:])
92
+		if checksum != sum.Sum32() {
93
+			return errors.Errorf("CRC32 checksum mismatch. Wait for %d, got %d", sum.Sum32(), checksum)
97 94
 
98
-	}
99
-	f.readBuf.Write(data)
95
+		}
96
+		f.Buffer.Write(data)
100 97
 
101
-	return f.flush(p)
98
+		return nil
99
+	})
102 100
 }
103 101
 
104 102
 func (f *FrameRWC) Close() error {
@@ -109,28 +107,11 @@ func (f *FrameRWC) Addr() *net.TCPAddr {
109 107
 	return f.conn.Addr()
110 108
 }
111 109
 
112
-func (f *FrameRWC) flush(p []byte) (int, error) {
113
-	sizeToRead := len(p)
114
-	if f.readBuf.Len() < sizeToRead {
115
-		sizeToRead = f.readBuf.Len()
116
-	}
117
-
118
-	data := f.readBuf.Bytes()
119
-	copy(p, data[:sizeToRead])
120
-	if sizeToRead == f.readBuf.Len() {
121
-		f.readBuf.Reset()
122
-	} else {
123
-		f.readBuf = bytes.NewBuffer(data[sizeToRead:])
124
-	}
125
-
126
-	return sizeToRead, nil
127
-}
128
-
129 110
 func NewFrameRWC(conn wrappers.ReadWriteCloserWithAddr, seqNo int32) wrappers.ReadWriteCloserWithAddr {
130 111
 	return &FrameRWC{
131
-		conn:       conn,
132
-		readSeqNo:  seqNo,
133
-		writeSeqNo: seqNo,
134
-		readBuf:    &bytes.Buffer{},
112
+		BufferedReader: wrappers.NewBufferedReader(),
113
+		conn:           conn,
114
+		readSeqNo:      seqNo,
115
+		writeSeqNo:     seqNo,
135 116
 	}
136 117
 }

+ 17
- 46
wrappers/blockcipherrwc.go Целия файл

@@ -1,7 +1,6 @@
1 1
 package wrappers
2 2
 
3 3
 import (
4
-	"bytes"
5 4
 	"crypto/aes"
6 5
 	"crypto/cipher"
7 6
 	"net"
@@ -10,7 +9,7 @@ import (
10 9
 )
11 10
 
12 11
 type BlockCipherReadWriteCloserWithAddr struct {
13
-	buf *bytes.Buffer
12
+	BufferedReader
14 13
 
15 14
 	conn      ReadWriteCloserWithAddr
16 15
 	encryptor cipher.BlockMode
@@ -18,20 +17,19 @@ type BlockCipherReadWriteCloserWithAddr struct {
18 17
 }
19 18
 
20 19
 func (c *BlockCipherReadWriteCloserWithAddr) Read(p []byte) (int, error) {
21
-	if c.buf.Len() > 0 {
22
-		return c.flush(p)
23
-	}
24
-
25
-	for c.buf.Len() == 0 || c.buf.Len()%aes.BlockSize != 0 {
26
-		n, err := c.conn.Read(p)
27
-		if err != nil {
28
-			return 0, errors.Annotate(err, "Cannot read from socket")
20
+	return c.BufferedRead(p, func() error {
21
+		bufferLength := c.Buffer.Len()
22
+		for bufferLength%aes.BlockSize != 0 || bufferLength == 0 {
23
+			n, err := c.conn.Read(p)
24
+			if err != nil {
25
+				return errors.Annotate(err, "Cannot read from socket")
26
+			}
27
+			c.Buffer.Write(p[:n])
29 28
 		}
30
-		c.buf.Write(p[:n])
31
-	}
32
-	c.decryptor.CryptBlocks(c.buf.Bytes(), c.buf.Bytes())
29
+		c.decryptor.CryptBlocks(c.Buffer.Bytes(), c.Buffer.Bytes())
33 30
 
34
-	return c.flush(p)
31
+		return nil
32
+	})
35 33
 }
36 34
 
37 35
 func (c *BlockCipherReadWriteCloserWithAddr) Write(p []byte) (int, error) {
@@ -39,19 +37,13 @@ func (c *BlockCipherReadWriteCloserWithAddr) Write(p []byte) (int, error) {
39 37
 		return 0, errors.Errorf("Incorrect block size %d", len(p))
40 38
 	}
41 39
 
42
-	buf := getBuffer()
43
-	defer putBuffer(buf)
44
-	buf.Grow(len(p))
45
-	buf.Write(p)
46
-
47
-	encrypted := buf.Bytes()
40
+	encrypted := make([]byte, len(p))
48 41
 	c.encryptor.CryptBlocks(encrypted, p)
49 42
 
50 43
 	return c.conn.Write(encrypted)
51 44
 }
52 45
 
53 46
 func (c *BlockCipherReadWriteCloserWithAddr) Close() error {
54
-	defer putBuffer(c.buf)
55 47
 	return c.conn.Close()
56 48
 }
57 49
 
@@ -59,32 +51,11 @@ func (c *BlockCipherReadWriteCloserWithAddr) Addr() *net.TCPAddr {
59 51
 	return c.conn.Addr()
60 52
 }
61 53
 
62
-func (c *BlockCipherReadWriteCloserWithAddr) flush(p []byte) (int, error) {
63
-	sizeToRead := len(p)
64
-	if c.buf.Len() < sizeToRead {
65
-		sizeToRead = c.buf.Len()
66
-	}
67
-
68
-	data := c.buf.Bytes()
69
-	copy(p, data[:sizeToRead])
70
-	if sizeToRead == c.buf.Len() {
71
-		c.buf.Reset()
72
-	} else {
73
-		newBuf := getBuffer()
74
-		newBuf.Write(data[sizeToRead:])
75
-
76
-		putBuffer(c.buf)
77
-		c.buf = newBuf
78
-	}
79
-
80
-	return sizeToRead, nil
81
-}
82
-
83 54
 func NewBlockCipherRWC(conn ReadWriteCloserWithAddr, encryptor, decryptor cipher.BlockMode) ReadWriteCloserWithAddr {
84 55
 	return &BlockCipherReadWriteCloserWithAddr{
85
-		buf:       getBuffer(),
86
-		conn:      conn,
87
-		encryptor: encryptor,
88
-		decryptor: decryptor,
56
+		BufferedReader: NewBufferedReader(),
57
+		conn:           conn,
58
+		encryptor:      encryptor,
59
+		decryptor:      decryptor,
89 60
 	}
90 61
 }

+ 38
- 0
wrappers/buffered_reader.go Целия файл

@@ -0,0 +1,38 @@
1
+package wrappers
2
+
3
+import "bytes"
4
+
5
+type BufferedReader struct {
6
+	Buffer *bytes.Buffer
7
+}
8
+
9
+func (b *BufferedReader) BufferedRead(p []byte, callback func() error) (int, error) {
10
+	if b.Buffer.Len() > 0 {
11
+		return b.flush(p)
12
+	}
13
+	if err := callback(); err != nil {
14
+		return 0, err
15
+	}
16
+	return b.flush(p)
17
+}
18
+
19
+func (b *BufferedReader) flush(p []byte) (int, error) {
20
+	sizeToRead := len(p)
21
+	if b.Buffer.Len() < sizeToRead {
22
+		sizeToRead = b.Buffer.Len()
23
+	}
24
+
25
+	data := b.Buffer.Bytes()
26
+	copy(p, data[:sizeToRead])
27
+	if sizeToRead == b.Buffer.Len() {
28
+		b.Buffer.Reset()
29
+	} else {
30
+		b.Buffer = bytes.NewBuffer(data[sizeToRead:])
31
+	}
32
+
33
+	return sizeToRead, nil
34
+}
35
+
36
+func NewBufferedReader() BufferedReader {
37
+	return BufferedReader{Buffer: &bytes.Buffer{}}
38
+}

Loading…
Отказ
Запис