9seconds 7 лет назад
Родитель
Сommit
420159a7d3
2 измененных файлов: 145 добавлений и 5 удалений
  1. 138
    0
      mtproto/wrappers/frame.go
  2. 7
    5
      wrappers/blockcipherrwc.go

+ 138
- 0
mtproto/wrappers/frame.go Просмотреть файл

1
+package wrappers
2
+
3
+import (
4
+	"bytes"
5
+	"crypto/aes"
6
+	"encoding/binary"
7
+	"hash/crc32"
8
+	"io"
9
+
10
+	"github.com/juju/errors"
11
+
12
+	"github.com/9seconds/mtg/mtproto"
13
+)
14
+
15
+// Frame: { MessageLength(4) | SequenceNumber(4) | Message(???) | CRC32(4) [| padding(4), ...] }
16
+const (
17
+	frameRWCMinMessageLength = 12
18
+	frameRWCMaxMessageLength = 16777216
19
+)
20
+
21
+var frameRWCPadding = [4]byte{0x04, 0x00, 0x00, 0x00}
22
+
23
+type FrameRWC struct {
24
+	conn mtproto.BytesRWC
25
+
26
+	readSeqNo  int32
27
+	writeSeqNo int32
28
+	readBuf    *bytes.Buffer
29
+}
30
+
31
+func (f *FrameRWC) Write(buf *bytes.Buffer) (int, error) {
32
+	writeBuf := mtproto.GetBuffer()
33
+	defer mtproto.ReturnBuffer(writeBuf)
34
+
35
+	// 4 - len bytes
36
+	// 4 - seq bytes
37
+	// . - message
38
+	// 4 - crc32
39
+	messageLength := 4 + 4 + buf.Len() + 4
40
+	paddingLength := (aes.BlockSize - messageLength%aes.BlockSize) % aes.BlockSize
41
+	writeBuf.Grow(messageLength + paddingLength)
42
+
43
+	binary.Write(writeBuf, binary.LittleEndian, uint32(messageLength))
44
+	binary.Write(writeBuf, binary.LittleEndian, f.writeSeqNo)
45
+	writeBuf.Write(buf.Bytes())
46
+	f.writeSeqNo++
47
+
48
+	checksum := crc32.ChecksumIEEE(writeBuf.Bytes())
49
+	binary.Write(writeBuf, binary.LittleEndian, checksum)
50
+	writeBuf.Write(bytes.Repeat(frameRWCPadding[:], paddingLength/4))
51
+
52
+	return f.conn.Write(writeBuf)
53
+}
54
+
55
+func (f *FrameRWC) Read(p []byte) (int, error) {
56
+	if f.readBuf.Len() > 0 {
57
+		return f.flush(p)
58
+	}
59
+
60
+	buf := mtproto.GetBuffer()
61
+	defer mtproto.ReturnBuffer(buf)
62
+
63
+	for {
64
+		buf.Reset()
65
+		if _, err := io.CopyN(buf, f.conn, 4); err != nil {
66
+			return 0, errors.Annotate(err, "Cannot read frame padding")
67
+		}
68
+		if !bytes.Equal(buf.Bytes(), frameRWCPadding[:]) {
69
+			break
70
+		}
71
+	}
72
+
73
+	messageLength := binary.LittleEndian.Uint32(buf.Bytes())
74
+	if messageLength%4 != 0 || messageLength < frameRWCMinMessageLength || messageLength > frameRWCMaxMessageLength {
75
+		return 0, errors.Errorf("Incorrect frame message length %d", messageLength)
76
+	}
77
+	sum := crc32.NewIEEE()
78
+	sum.Write(buf.Bytes())
79
+
80
+	buf.Reset()
81
+	buf.Grow(int(messageLength) - 4) // -4 because we already read the first number
82
+	if _, err := io.CopyN(buf, f.conn, int64(messageLength)-4); err != nil {
83
+		return 0, errors.Annotate(err, "Cannot read the message frame")
84
+	}
85
+	sum.Write(buf.Bytes())
86
+
87
+	var seqNo int32
88
+	binary.Read(buf, binary.LittleEndian, seqNo)
89
+	if seqNo != f.readSeqNo {
90
+		return 0, errors.Errorf("Unexpected sequence number %d (wait for %d)", seqNo, f.readSeqNo)
91
+	}
92
+	f.readSeqNo++
93
+
94
+	data := buf.Bytes()[:int(messageLength)-4-4-4]
95
+	checksum := binary.LittleEndian.Uint32(buf.Bytes()[int(messageLength)-4-4-4:])
96
+	if checksum != sum.Sum32() {
97
+		return 0, errors.Errorf("CRC32 checksum mismatch. Wait for %d, got %d", sum.Sum32(), checksum)
98
+
99
+	}
100
+	f.readBuf.Write(data)
101
+
102
+	return f.flush(p)
103
+}
104
+
105
+func (f *FrameRWC) Close() error {
106
+	defer mtproto.ReturnBuffer(f.readBuf)
107
+	return f.conn.Close()
108
+}
109
+
110
+func (f *FrameRWC) flush(p []byte) (int, error) {
111
+	sizeToRead := len(p)
112
+	if f.readBuf.Len() < sizeToRead {
113
+		sizeToRead = f.readBuf.Len()
114
+	}
115
+
116
+	data := f.readBuf.Bytes()
117
+	copy(p, data[:sizeToRead])
118
+	if sizeToRead == f.readBuf.Len() {
119
+		f.readBuf.Reset()
120
+	} else {
121
+		newBuf := mtproto.GetBuffer()
122
+		newBuf.Write(data[sizeToRead:])
123
+
124
+		mtproto.ReturnBuffer(f.readBuf)
125
+		f.readBuf = newBuf
126
+	}
127
+
128
+	return sizeToRead, nil
129
+}
130
+
131
+func NewFrameRWC(conn mtproto.BytesRWC, seqNo int32) mtproto.BytesRWC {
132
+	return &FrameRWC{
133
+		conn:       conn,
134
+		readSeqNo:  seqNo,
135
+		writeSeqNo: seqNo,
136
+		readBuf:    mtproto.GetBuffer(),
137
+	}
138
+}

+ 7
- 5
wrappers/blockcipherrwc.go Просмотреть файл

29
 		}
29
 		}
30
 		c.buf.Write(p[:n])
30
 		c.buf.Write(p[:n])
31
 	}
31
 	}
32
+	c.decryptor.CryptBlocks(c.buf.Bytes(), c.buf.Bytes())
32
 
33
 
33
 	return c.flush(p)
34
 	return c.flush(p)
34
 }
35
 }
59
 	if c.buf.Len() < sizeToRead {
60
 	if c.buf.Len() < sizeToRead {
60
 		sizeToRead = c.buf.Len()
61
 		sizeToRead = c.buf.Len()
61
 	}
62
 	}
62
-	sizeToRead = aes.BlockSize * (sizeToRead / aes.BlockSize)
63
 
63
 
64
-	c.decryptor.CryptBlocks(p, c.buf.Bytes()[:sizeToRead])
64
+	data := c.buf.Bytes()
65
+	copy(p, data[:sizeToRead])
65
 	if sizeToRead == c.buf.Len() {
66
 	if sizeToRead == c.buf.Len() {
66
 		c.buf.Reset()
67
 		c.buf.Reset()
67
 	} else {
68
 	} else {
68
-		leftover := c.buf.Bytes()[sizeToRead:]
69
+		newBuf := getBuffer()
70
+		newBuf.Write(data[sizeToRead:])
71
+
69
 		putBuffer(c.buf)
72
 		putBuffer(c.buf)
70
-		c.buf = getBuffer()
71
-		c.buf.Write(leftover)
73
+		c.buf = newBuf
72
 	}
74
 	}
73
 
75
 
74
 	return sizeToRead, nil
76
 	return sizeToRead, nil

Загрузка…
Отмена
Сохранить