9seconds 7 лет назад
Родитель
Сommit
4c4671ad3b
2 измененных файлов: 26 добавлений и 14 удалений
  1. 18
    9
      mtproto/wrappers/frame.go
  2. 8
    5
      telegram/middle.go

+ 18
- 9
mtproto/wrappers/frame.go Просмотреть файл

6
 	"encoding/binary"
6
 	"encoding/binary"
7
 	"hash/crc32"
7
 	"hash/crc32"
8
 	"io"
8
 	"io"
9
+	"io/ioutil"
9
 	"net"
10
 	"net"
10
 
11
 
11
 	"github.com/juju/errors"
12
 	"github.com/juju/errors"
56
 func (f *FrameRWC) Read(p []byte) (int, error) {
57
 func (f *FrameRWC) Read(p []byte) (int, error) {
57
 	return f.BufferedRead(p, func() error {
58
 	return f.BufferedRead(p, func() error {
58
 		buf := &bytes.Buffer{}
59
 		buf := &bytes.Buffer{}
60
+		sum := crc32.NewIEEE()
61
+		writer := io.MultiWriter(buf, sum)
62
+
59
 		for {
63
 		for {
60
 			buf.Reset()
64
 			buf.Reset()
61
-			if _, err := io.CopyN(buf, f.conn, 4); err != nil {
65
+			sum.Reset()
66
+			if _, err := io.CopyN(writer, f.conn, 4); err != nil {
62
 				return errors.Annotate(err, "Cannot read frame padding")
67
 				return errors.Annotate(err, "Cannot read frame padding")
63
 			}
68
 			}
64
 			if !bytes.Equal(buf.Bytes(), frameRWCPadding[:]) {
69
 			if !bytes.Equal(buf.Bytes(), frameRWCPadding[:]) {
70
 		if messageLength%4 != 0 || messageLength < frameRWCMinMessageLength || messageLength > frameRWCMaxMessageLength {
75
 		if messageLength%4 != 0 || messageLength < frameRWCMinMessageLength || messageLength > frameRWCMaxMessageLength {
71
 			return errors.Errorf("Incorrect frame message length %d", messageLength)
76
 			return errors.Errorf("Incorrect frame message length %d", messageLength)
72
 		}
77
 		}
73
-		sum := crc32.NewIEEE()
74
-		sum.Write(buf.Bytes())
75
 
78
 
76
 		buf.Reset()
79
 		buf.Reset()
77
-		buf.Grow(int(messageLength) - 4) // -4 because we already read the first number
78
-		if _, err := io.CopyN(buf, f.conn, int64(messageLength)-4); err != nil {
80
+		buf.Grow(int(messageLength) - 4 - 4)
81
+		if _, err := io.CopyN(writer, f.conn, int64(messageLength)-4-4); err != nil {
79
 			return errors.Annotate(err, "Cannot read the message frame")
82
 			return errors.Annotate(err, "Cannot read the message frame")
80
 		}
83
 		}
81
-		sum.Write(buf.Bytes())
82
 
84
 
83
 		var seqNo int32
85
 		var seqNo int32
84
-		binary.Read(buf, binary.LittleEndian, seqNo)
86
+		binary.Read(buf, binary.LittleEndian, &seqNo)
85
 		if seqNo != f.readSeqNo {
87
 		if seqNo != f.readSeqNo {
86
 			return errors.Errorf("Unexpected sequence number %d (wait for %d)", seqNo, f.readSeqNo)
88
 			return errors.Errorf("Unexpected sequence number %d (wait for %d)", seqNo, f.readSeqNo)
87
 		}
89
 		}
88
 		f.readSeqNo++
90
 		f.readSeqNo++
89
 
91
 
90
-		data := buf.Bytes()[:int(messageLength)-4-4-4]
91
-		checksum := binary.LittleEndian.Uint32(buf.Bytes()[int(messageLength)-4-4-4:])
92
+		data, _ := ioutil.ReadAll(buf)
93
+		buf.Reset()
94
+		// write to buf, not to writer. This is because we are going to fetch
95
+		// crc32 checksum.
96
+		if _, err := io.CopyN(buf, f.conn, 4); err != nil {
97
+			return errors.Annotate(err, "Cannot read checksum")
98
+		}
99
+		checksum := binary.LittleEndian.Uint32(buf.Bytes())
100
+
92
 		if checksum != sum.Sum32() {
101
 		if checksum != sum.Sum32() {
93
 			return errors.Errorf("CRC32 checksum mismatch. Wait for %d, got %d", sum.Sum32(), checksum)
102
 			return errors.Errorf("CRC32 checksum mismatch. Wait for %d, got %d", sum.Sum32(), checksum)
94
 
103
 

+ 8
- 5
telegram/middle.go Просмотреть файл

2
 
2
 
3
 import (
3
 import (
4
 	"io"
4
 	"io"
5
-	"io/ioutil"
6
 	"net"
5
 	"net"
7
 	"net/http"
6
 	"net/http"
8
 	"sync"
7
 	"sync"
84
 }
83
 }
85
 
84
 
86
 func (t *middleTelegram) receiveRPCNonceResponse(conn io.Reader, req *rpc.RPCNonceRequest) (*rpc.RPCNonceResponse, error) {
85
 func (t *middleTelegram) receiveRPCNonceResponse(conn io.Reader, req *rpc.RPCNonceRequest) (*rpc.RPCNonceResponse, error) {
87
-	ans, err := ioutil.ReadAll(conn)
86
+	var ans [128]byte
87
+
88
+	n, err := conn.Read(ans[:])
88
 	if err != nil {
89
 	if err != nil {
89
 		return nil, errors.Annotate(err, "Cannot read RPC nonce response")
90
 		return nil, errors.Annotate(err, "Cannot read RPC nonce response")
90
 	}
91
 	}
91
-	rpcNonceResp, err := rpc.NewRPCNonceResponse(ans)
92
+	rpcNonceResp, err := rpc.NewRPCNonceResponse(ans[:n])
92
 	if err != nil {
93
 	if err != nil {
93
 		return nil, errors.Annotate(err, "Cannot initialize RPC nonce response")
94
 		return nil, errors.Annotate(err, "Cannot initialize RPC nonce response")
94
 	}
95
 	}
109
 }
110
 }
110
 
111
 
111
 func (t *middleTelegram) receiveRPCHandshakeResponse(conn io.Reader, req *rpc.RPCHandshakeRequest) (*rpc.RPCHandshakeResponse, error) {
112
 func (t *middleTelegram) receiveRPCHandshakeResponse(conn io.Reader, req *rpc.RPCHandshakeRequest) (*rpc.RPCHandshakeResponse, error) {
112
-	ans, err := ioutil.ReadAll(conn)
113
+	var ans [128]byte
114
+
115
+	n, err := conn.Read(ans[:])
113
 	if err != nil {
116
 	if err != nil {
114
 		return nil, errors.Annotate(err, "Cannot read RPC handshake response")
117
 		return nil, errors.Annotate(err, "Cannot read RPC handshake response")
115
 	}
118
 	}
116
-	rpcHandshakeResp, err := rpc.NewRPCHandshakeResponse(ans)
119
+	rpcHandshakeResp, err := rpc.NewRPCHandshakeResponse(ans[:n])
117
 	if err != nil {
120
 	if err != nil {
118
 		return nil, errors.Annotate(err, "Cannot initialize RPC handshake response")
121
 		return nil, errors.Annotate(err, "Cannot initialize RPC handshake response")
119
 	}
122
 	}

Загрузка…
Отмена
Сохранить