9seconds 7 лет назад
Родитель
Сommit
4c4671ad3b
2 измененных файлов: 26 добавлений и 14 удалений
  1. 18
    9
      mtproto/wrappers/frame.go
  2. 8
    5
      telegram/middle.go

+ 18
- 9
mtproto/wrappers/frame.go Просмотреть файл

@@ -6,6 +6,7 @@ import (
6 6
 	"encoding/binary"
7 7
 	"hash/crc32"
8 8
 	"io"
9
+	"io/ioutil"
9 10
 	"net"
10 11
 
11 12
 	"github.com/juju/errors"
@@ -56,9 +57,13 @@ func (f *FrameRWC) Write(buf []byte) (int, error) {
56 57
 func (f *FrameRWC) Read(p []byte) (int, error) {
57 58
 	return f.BufferedRead(p, func() error {
58 59
 		buf := &bytes.Buffer{}
60
+		sum := crc32.NewIEEE()
61
+		writer := io.MultiWriter(buf, sum)
62
+
59 63
 		for {
60 64
 			buf.Reset()
61
-			if _, err := io.CopyN(buf, f.conn, 4); err != nil {
65
+			sum.Reset()
66
+			if _, err := io.CopyN(writer, f.conn, 4); err != nil {
62 67
 				return errors.Annotate(err, "Cannot read frame padding")
63 68
 			}
64 69
 			if !bytes.Equal(buf.Bytes(), frameRWCPadding[:]) {
@@ -70,25 +75,29 @@ func (f *FrameRWC) Read(p []byte) (int, error) {
70 75
 		if messageLength%4 != 0 || messageLength < frameRWCMinMessageLength || messageLength > frameRWCMaxMessageLength {
71 76
 			return errors.Errorf("Incorrect frame message length %d", messageLength)
72 77
 		}
73
-		sum := crc32.NewIEEE()
74
-		sum.Write(buf.Bytes())
75 78
 
76 79
 		buf.Reset()
77
-		buf.Grow(int(messageLength) - 4) // -4 because we already read the first number
78
-		if _, err := io.CopyN(buf, f.conn, int64(messageLength)-4); err != nil {
80
+		buf.Grow(int(messageLength) - 4 - 4)
81
+		if _, err := io.CopyN(writer, f.conn, int64(messageLength)-4-4); err != nil {
79 82
 			return errors.Annotate(err, "Cannot read the message frame")
80 83
 		}
81
-		sum.Write(buf.Bytes())
82 84
 
83 85
 		var seqNo int32
84
-		binary.Read(buf, binary.LittleEndian, seqNo)
86
+		binary.Read(buf, binary.LittleEndian, &seqNo)
85 87
 		if seqNo != f.readSeqNo {
86 88
 			return errors.Errorf("Unexpected sequence number %d (wait for %d)", seqNo, f.readSeqNo)
87 89
 		}
88 90
 		f.readSeqNo++
89 91
 
90
-		data := buf.Bytes()[:int(messageLength)-4-4-4]
91
-		checksum := binary.LittleEndian.Uint32(buf.Bytes()[int(messageLength)-4-4-4:])
92
+		data, _ := ioutil.ReadAll(buf)
93
+		buf.Reset()
94
+		// write to buf, not to writer. This is because we are going to fetch
95
+		// crc32 checksum.
96
+		if _, err := io.CopyN(buf, f.conn, 4); err != nil {
97
+			return errors.Annotate(err, "Cannot read checksum")
98
+		}
99
+		checksum := binary.LittleEndian.Uint32(buf.Bytes())
100
+
92 101
 		if checksum != sum.Sum32() {
93 102
 			return errors.Errorf("CRC32 checksum mismatch. Wait for %d, got %d", sum.Sum32(), checksum)
94 103
 

+ 8
- 5
telegram/middle.go Просмотреть файл

@@ -2,7 +2,6 @@ package telegram
2 2
 
3 3
 import (
4 4
 	"io"
5
-	"io/ioutil"
6 5
 	"net"
7 6
 	"net/http"
8 7
 	"sync"
@@ -84,11 +83,13 @@ func (t *middleTelegram) sendRPCNonceRequest(conn io.Writer) (*rpc.RPCNonceReque
84 83
 }
85 84
 
86 85
 func (t *middleTelegram) receiveRPCNonceResponse(conn io.Reader, req *rpc.RPCNonceRequest) (*rpc.RPCNonceResponse, error) {
87
-	ans, err := ioutil.ReadAll(conn)
86
+	var ans [128]byte
87
+
88
+	n, err := conn.Read(ans[:])
88 89
 	if err != nil {
89 90
 		return nil, errors.Annotate(err, "Cannot read RPC nonce response")
90 91
 	}
91
-	rpcNonceResp, err := rpc.NewRPCNonceResponse(ans)
92
+	rpcNonceResp, err := rpc.NewRPCNonceResponse(ans[:n])
92 93
 	if err != nil {
93 94
 		return nil, errors.Annotate(err, "Cannot initialize RPC nonce response")
94 95
 	}
@@ -109,11 +110,13 @@ func (t *middleTelegram) sendRPCHandshakeRequest(conn io.Writer) (*rpc.RPCHandsh
109 110
 }
110 111
 
111 112
 func (t *middleTelegram) receiveRPCHandshakeResponse(conn io.Reader, req *rpc.RPCHandshakeRequest) (*rpc.RPCHandshakeResponse, error) {
112
-	ans, err := ioutil.ReadAll(conn)
113
+	var ans [128]byte
114
+
115
+	n, err := conn.Read(ans[:])
113 116
 	if err != nil {
114 117
 		return nil, errors.Annotate(err, "Cannot read RPC handshake response")
115 118
 	}
116
-	rpcHandshakeResp, err := rpc.NewRPCHandshakeResponse(ans)
119
+	rpcHandshakeResp, err := rpc.NewRPCHandshakeResponse(ans[:n])
117 120
 	if err != nil {
118 121
 		return nil, errors.Annotate(err, "Cannot initialize RPC handshake response")
119 122
 	}

Загрузка…
Отмена
Сохранить