Parcourir la source

Move frame wrapper

tags/0.9
9seconds il y a 7 ans
Parent
révision
1d93bef913
3 fichiers modifiés avec 160 ajouts et 143 suppressions
  1. 0
    134
      mtproto/wrappers/frame.go
  2. 9
    9
      wrappers/mtproto_abridged.go
  3. 151
    0
      wrappers/mtproto_frame.go

+ 0
- 134
mtproto/wrappers/frame.go Voir le fichier

@@ -1,134 +0,0 @@
1
-package wrappers
2
-
3
-import (
4
-	"bytes"
5
-	"crypto/aes"
6
-	"encoding/binary"
7
-	"hash/crc32"
8
-	"io"
9
-	"io/ioutil"
10
-	"net"
11
-
12
-	"github.com/juju/errors"
13
-
14
-	"github.com/9seconds/mtg/wrappers"
15
-)
16
-
17
-// Frame: { MessageLength(4) | SequenceNumber(4) | Message(???) | CRC32(4) [| padding(4), ...] }
18
-const (
19
-	frameRWCMinMessageLength = 12
20
-	frameRWCMaxMessageLength = 16777216
21
-)
22
-
23
-var frameRWCPadding = []byte{0x04, 0x00, 0x00, 0x00}
24
-
25
-type FrameRWC struct {
26
-	wrappers.BufferedReader
27
-
28
-	conn       wrappers.ReadWriteCloserWithAddr
29
-	readSeqNo  int32
30
-	writeSeqNo int32
31
-}
32
-
33
-func (f *FrameRWC) Write(buf []byte) (int, error) {
34
-	writeBuf := &bytes.Buffer{}
35
-
36
-	// 4 - len bytes
37
-	// 4 - seq bytes
38
-	// . - message
39
-	// 4 - crc32
40
-	messageLength := 4 + 4 + len(buf) + 4
41
-	paddingLength := (aes.BlockSize - messageLength%aes.BlockSize) % aes.BlockSize
42
-	writeBuf.Grow(messageLength + paddingLength)
43
-
44
-	binary.Write(writeBuf, binary.LittleEndian, uint32(messageLength))
45
-	binary.Write(writeBuf, binary.LittleEndian, f.writeSeqNo)
46
-	writeBuf.Write(buf)
47
-	f.writeSeqNo++
48
-
49
-	checksum := crc32.ChecksumIEEE(writeBuf.Bytes())
50
-	binary.Write(writeBuf, binary.LittleEndian, checksum)
51
-	writeBuf.Write(bytes.Repeat(frameRWCPadding, paddingLength/4))
52
-
53
-	_, err := f.conn.Write(writeBuf.Bytes())
54
-	return len(buf), err
55
-}
56
-
57
-func (f *FrameRWC) Read(p []byte) (int, error) {
58
-	return f.BufferedRead(p, func() error {
59
-		buf := &bytes.Buffer{}
60
-		sum := crc32.NewIEEE()
61
-		writer := io.MultiWriter(buf, sum)
62
-
63
-		for {
64
-			buf.Reset()
65
-			sum.Reset()
66
-			if _, err := io.CopyN(writer, f.conn, 4); err != nil {
67
-				return errors.Annotate(err, "Cannot read frame padding")
68
-			}
69
-			if !bytes.Equal(buf.Bytes(), frameRWCPadding) {
70
-				break
71
-			}
72
-		}
73
-
74
-		messageLength := binary.LittleEndian.Uint32(buf.Bytes())
75
-		if messageLength%4 != 0 || messageLength < frameRWCMinMessageLength || messageLength > frameRWCMaxMessageLength {
76
-			return errors.Errorf("Incorrect frame message length %d", messageLength)
77
-		}
78
-
79
-		buf.Reset()
80
-		buf.Grow(int(messageLength) - 4 - 4)
81
-		if _, err := io.CopyN(writer, f.conn, int64(messageLength)-4-4); err != nil {
82
-			return errors.Annotate(err, "Cannot read the message frame")
83
-		}
84
-
85
-		var seqNo int32
86
-		binary.Read(buf, binary.LittleEndian, &seqNo)
87
-		if seqNo != f.readSeqNo {
88
-			return errors.Errorf("Unexpected sequence number %d (wait for %d)", seqNo, f.readSeqNo)
89
-		}
90
-		f.readSeqNo++
91
-
92
-		data, _ := ioutil.ReadAll(buf)
93
-		buf.Reset()
94
-		// write to buf, not to writer. This is because we are going to fetch
95
-		// crc32 checksum.
96
-		if _, err := io.CopyN(buf, f.conn, 4); err != nil {
97
-			return errors.Annotate(err, "Cannot read checksum")
98
-		}
99
-		checksum := binary.LittleEndian.Uint32(buf.Bytes())
100
-
101
-		if checksum != sum.Sum32() {
102
-			return errors.Errorf("CRC32 checksum mismatch. Wait for %d, got %d", sum.Sum32(), checksum)
103
-
104
-		}
105
-		f.Buffer.Write(data)
106
-
107
-		return nil
108
-	})
109
-}
110
-
111
-func (f *FrameRWC) Close() error {
112
-	return f.conn.Close()
113
-}
114
-
115
-func (f *FrameRWC) LocalAddr() *net.TCPAddr {
116
-	return f.conn.LocalAddr()
117
-}
118
-
119
-func (f *FrameRWC) RemoteAddr() *net.TCPAddr {
120
-	return f.conn.RemoteAddr()
121
-}
122
-
123
-func (f *FrameRWC) SocketID() string {
124
-	return f.conn.SocketID()
125
-}
126
-
127
-func NewFrameRWC(conn wrappers.ReadWriteCloserWithAddr, seqNo int32) wrappers.ReadWriteCloserWithAddr {
128
-	return &FrameRWC{
129
-		BufferedReader: wrappers.NewBufferedReader(),
130
-		conn:           conn,
131
-		readSeqNo:      seqNo,
132
-		writeSeqNo:     seqNo,
133
-	}
134
-}

+ 9
- 9
wrappers/mtproto_abridged.go Voir le fichier

@@ -12,9 +12,9 @@ import (
12 12
 )
13 13
 
14 14
 const (
15
-	abridgedSmallPacketLength = 0x7f
16
-	abridgedQuickAckLength    = 0x80
17
-	abridgedLargePacketLength = 16777216 // 256 ^ 3
15
+	mtprotoAbridgedSmallPacketLength = 0x7f
16
+	mtprotoAbridgedQuickAckLength    = 0x80
17
+	mtprotoAbridgedLargePacketLength = 16777216 // 256 ^ 3
18 18
 )
19 19
 
20 20
 type MTProtoAbridged struct {
@@ -46,13 +46,13 @@ func (m *MTProtoAbridged) Read() ([]byte, error) {
46 46
 		"counter", m.readCounter,
47 47
 	)
48 48
 
49
-	if msgLength >= abridgedQuickAckLength {
49
+	if msgLength >= mtprotoAbridgedQuickAckLength {
50 50
 		m.opts.ReadHacks.QuickAck = true
51
-		msgLength -= abridgedQuickAckLength
51
+		msgLength -= mtprotoAbridgedQuickAckLength
52 52
 	}
53 53
 
54 54
 	msgLength32 := uint32(msgLength)
55
-	if msgLength == abridgedSmallPacketLength {
55
+	if msgLength == mtprotoAbridgedSmallPacketLength {
56 56
 		if _, err := io.CopyN(buf, m.conn, 3); err != nil {
57 57
 			return nil, errors.Annotate(err, "Cannot read the correct message length")
58 58
 		}
@@ -96,19 +96,19 @@ func (m *MTProtoAbridged) Write(p []byte) (int, error) {
96 96
 
97 97
 	packetLength := len(p) / 4
98 98
 	switch {
99
-	case packetLength < abridgedSmallPacketLength:
99
+	case packetLength < mtprotoAbridgedSmallPacketLength:
100 100
 		newData := append([]byte{byte(packetLength)}, p...)
101 101
 
102 102
 		m.writeCounter++
103 103
 		return m.conn.Write(newData)
104 104
 
105
-	case packetLength < abridgedLargePacketLength:
105
+	case packetLength < mtprotoAbridgedLargePacketLength:
106 106
 		length24 := utils.ToUint24(uint32(packetLength))
107 107
 
108 108
 		buf := &bytes.Buffer{}
109 109
 		buf.Grow(1 + 3 + len(p))
110 110
 
111
-		buf.WriteByte(byte(abridgedSmallPacketLength))
111
+		buf.WriteByte(byte(mtprotoAbridgedSmallPacketLength))
112 112
 		buf.Write(length24[:])
113 113
 		buf.Write(p)
114 114
 

+ 151
- 0
wrappers/mtproto_frame.go Voir le fichier

@@ -0,0 +1,151 @@
1
+package wrappers
2
+
3
+import (
4
+	"bytes"
5
+	"crypto/aes"
6
+	"encoding/binary"
7
+	"hash/crc32"
8
+	"io"
9
+	"io/ioutil"
10
+	"net"
11
+
12
+	"github.com/juju/errors"
13
+)
14
+
15
+const (
16
+	mtprotoFrameMinMessageLength = 12
17
+	mtprotoFrameMaxMessageLength = 16777216
18
+)
19
+
20
+var mtprotoFramePadding = []byte{0x04, 0x00, 0x00, 0x00}
21
+
22
+type MTProtoFrame struct {
23
+	conn       WrapStreamReadWriteCloser
24
+	readSeqNo  int32
25
+	writeSeqNo int32
26
+}
27
+
28
+func (m *MTProtoFrame) Read() ([]byte, error) {
29
+	buf := &bytes.Buffer{}
30
+	sum := crc32.NewIEEE()
31
+	writer := io.MultiWriter(buf, sum)
32
+
33
+	for {
34
+		buf.Reset()
35
+		sum.Reset()
36
+		if _, err := io.CopyN(writer, m.conn, 4); err != nil {
37
+			return nil, errors.Annotate(err, "Cannot read frame padding")
38
+		}
39
+		if !bytes.Equal(buf.Bytes(), mtprotoFramePadding) {
40
+			break
41
+		}
42
+	}
43
+
44
+	messageLength := binary.LittleEndian.Uint32(buf.Bytes())
45
+	m.LogDebug("Read MTProto frame",
46
+		"messageLength", messageLength,
47
+		"sequence_number", m.readSeqNo,
48
+	)
49
+	if messageLength%4 != 0 || messageLength < mtprotoFrameMinMessageLength || messageLength > mtprotoFrameMaxMessageLength {
50
+		return nil, errors.Errorf("Incorrect frame message length %d", messageLength)
51
+	}
52
+
53
+	buf.Reset()
54
+	buf.Grow(int(messageLength) - 4 - 4)
55
+	if _, err := io.CopyN(writer, m.conn, int64(messageLength)-4-4); err != nil {
56
+		return nil, errors.Annotate(err, "Cannot read the message frame")
57
+	}
58
+
59
+	var seqNo int32
60
+	binary.Read(buf, binary.LittleEndian, &seqNo)
61
+	if seqNo != m.readSeqNo {
62
+		return nil, errors.Errorf("Unexpected sequence number %d (wait for %d)", seqNo, m.readSeqNo)
63
+	}
64
+
65
+	data, _ := ioutil.ReadAll(buf)
66
+	buf.Reset()
67
+	// write to buf, not to writer. This is because we are going to fetch
68
+	// crc32 checksum.
69
+	if _, err := io.CopyN(buf, m.conn, 4); err != nil {
70
+		return nil, errors.Annotate(err, "Cannot read checksum")
71
+	}
72
+
73
+	checksum := binary.LittleEndian.Uint32(buf.Bytes())
74
+	if checksum != sum.Sum32() {
75
+		return nil, errors.Errorf("CRC32 checksum mismatch. Wait for %d, got %d", sum.Sum32(), checksum)
76
+	}
77
+
78
+	m.LogDebug("Read MTProto frame",
79
+		"messageLength", messageLength,
80
+		"sequence_number", m.readSeqNo,
81
+		"dataLength", len(data),
82
+		"checksum", checksum,
83
+	)
84
+	m.readSeqNo++
85
+
86
+	return data, nil
87
+}
88
+
89
+func (m *MTProtoFrame) Write(p []byte) (int, error) {
90
+	messageLength := 4 + 4 + len(p) + 4
91
+	paddingLength := (aes.BlockSize - messageLength%aes.BlockSize) % aes.BlockSize
92
+
93
+	buf := &bytes.Buffer{}
94
+	buf.Grow(messageLength + paddingLength)
95
+
96
+	binary.Write(buf, binary.LittleEndian, uint32(messageLength))
97
+	binary.Write(buf, binary.LittleEndian, m.writeSeqNo)
98
+	buf.Write(p)
99
+
100
+	checksum := crc32.ChecksumIEEE(buf.Bytes())
101
+	binary.Write(buf, binary.LittleEndian, checksum)
102
+	buf.Write(bytes.Repeat(mtprotoFramePadding, paddingLength/4))
103
+
104
+	m.LogDebug("Write MTProto frame",
105
+		"length", len(p),
106
+		"sequence_number", m.writeSeqNo,
107
+		"crc32", checksum,
108
+		"frame_length", buf.Len(),
109
+	)
110
+	m.writeSeqNo++
111
+
112
+	_, err := m.conn.Write(buf.Bytes())
113
+
114
+	return len(p), err
115
+}
116
+
117
+func (m *MTProtoFrame) LogDebug(msg string, data ...interface{}) {
118
+	m.conn.LogDebug(msg, data...)
119
+}
120
+
121
+func (m *MTProtoFrame) LogInfo(msg string, data ...interface{}) {
122
+	m.conn.LogInfo(msg, data...)
123
+}
124
+
125
+func (m *MTProtoFrame) LogWarn(msg string, data ...interface{}) {
126
+	m.conn.LogWarn(msg, data...)
127
+}
128
+
129
+func (m *MTProtoFrame) LogError(msg string, data ...interface{}) {
130
+	m.conn.LogError(msg, data...)
131
+}
132
+
133
+func (m *MTProtoFrame) LocalAddr() *net.TCPAddr {
134
+	return m.conn.LocalAddr()
135
+}
136
+
137
+func (m *MTProtoFrame) RemoteAddr() *net.TCPAddr {
138
+	return m.conn.RemoteAddr()
139
+}
140
+
141
+func (m *MTProtoFrame) Close() error {
142
+	return m.conn.Close()
143
+}
144
+
145
+func NewMTProtoFrame(conn WrapStreamReadWriteCloser, seqNo int32) WrapPacketReadWriteCloser {
146
+	return &MTProtoFrame{
147
+		conn:       conn,
148
+		readSeqNo:  seqNo,
149
+		writeSeqNo: seqNo,
150
+	}
151
+}

Chargement…
Annuler
Enregistrer