Explorar el Código

Start to rewrite

tags/0.9
9seconds hace 7 años
padre
commit
41b2b1c819

+ 0
- 130
mtproto/wrappers/abridged.go Ver fichero

@@ -1,130 +0,0 @@
1
-package wrappers
2
-
3
-import (
4
-	"bytes"
5
-	"io"
6
-	"net"
7
-
8
-	"github.com/juju/errors"
9
-
10
-	"github.com/9seconds/mtg/mtproto"
11
-	"github.com/9seconds/mtg/wrappers"
12
-)
13
-
14
-type uint24 [3]byte
15
-
16
-const (
17
-	abridgedSmallPacketLength = 0x7f
18
-	abridgedQuickAckLength    = 0x80
19
-	abridgedLargePacketLength = 16777216 // 256 ^ 3
20
-)
21
-
22
-type AbridgedReadWriteCloserWithAddr struct {
23
-	wrappers.BufferedReader
24
-
25
-	conn wrappers.ReadWriteCloserWithAddr
26
-	opts *mtproto.ConnectionOpts
27
-}
28
-
29
-func (a *AbridgedReadWriteCloserWithAddr) Read(p []byte) (int, error) {
30
-	return a.BufferedRead(p, func() error {
31
-		buf := &bytes.Buffer{}
32
-		buf.Grow(3)
33
-
34
-		q := make([]byte, 1)
35
-
36
-		if _, err := io.CopyN(buf, a.conn, 1); err != nil {
37
-			return errors.Annotate(err, "Cannot read message length")
38
-		}
39
-		msgLength := uint8(buf.Bytes()[0])
40
-		q[0] = msgLength
41
-		buf.Reset()
42
-
43
-		if msgLength >= abridgedQuickAckLength {
44
-			a.opts.ReadHacks.QuickAck = true
45
-			msgLength -= abridgedQuickAckLength
46
-		}
47
-
48
-		msgLength32 := uint32(msgLength)
49
-		if msgLength == abridgedSmallPacketLength {
50
-			if _, err := io.CopyN(buf, a.conn, 3); err != nil {
51
-				return errors.Annotate(err, "Cannot read the correct message length")
52
-			}
53
-			number := uint24{}
54
-			copy(number[:], buf.Bytes())
55
-			q = append(q, buf.Bytes()...)
56
-			msgLength32 = fromUint24(number)
57
-		}
58
-		msgLength32 *= 4
59
-
60
-		buf.Reset()
61
-		buf.Grow(int(msgLength32))
62
-
63
-		if _, err := io.CopyN(buf, a.conn, int64(msgLength32)); err != nil {
64
-			return errors.Annotate(err, "Cannot read message")
65
-		}
66
-		q = append(q, buf.Bytes()...)
67
-		a.Buffer.Write(buf.Bytes())
68
-
69
-		return nil
70
-	})
71
-}
72
-
73
-func (a *AbridgedReadWriteCloserWithAddr) Write(p []byte) (int, error) {
74
-	if len(p)%4 != 0 {
75
-		return 0, errors.Errorf("Incorrect packet length %d", len(p))
76
-	}
77
-	if a.opts.WriteHacks.SimpleAck {
78
-		return a.conn.Write(reverseBytes(p))
79
-	}
80
-
81
-	packetLength := len(p) / 4
82
-	switch {
83
-	case packetLength < abridgedSmallPacketLength:
84
-		newData := append([]byte{byte(packetLength)}, p...)
85
-		return a.conn.Write(newData)
86
-
87
-	case packetLength < abridgedLargePacketLength:
88
-		length24 := toUint24(uint32(packetLength))
89
-		buf := &bytes.Buffer{}
90
-		buf.Grow(1 + 3 + len(p))
91
-		buf.WriteByte(byte(abridgedSmallPacketLength))
92
-		buf.Write(length24[:])
93
-		buf.Write(p)
94
-		return a.conn.Write(buf.Bytes())
95
-	}
96
-
97
-	return 0, errors.Errorf("Packet is too big %d", len(p))
98
-}
99
-
100
-func (a *AbridgedReadWriteCloserWithAddr) Close() error {
101
-	return a.conn.Close()
102
-}
103
-
104
-func (a *AbridgedReadWriteCloserWithAddr) LocalAddr() *net.TCPAddr {
105
-	return a.conn.LocalAddr()
106
-}
107
-
108
-func (a *AbridgedReadWriteCloserWithAddr) RemoteAddr() *net.TCPAddr {
109
-	return a.conn.RemoteAddr()
110
-}
111
-
112
-func (a *AbridgedReadWriteCloserWithAddr) SocketID() string {
113
-	return a.conn.SocketID()
114
-}
115
-
116
-func toUint24(number uint32) uint24 {
117
-	return uint24{byte(number), byte(number >> 8), byte(number >> 16)}
118
-}
119
-
120
-func fromUint24(number uint24) uint32 {
121
-	return uint32(number[0]) + (uint32(number[1]) << 8) + (uint32(number[2]) << 16)
122
-}
123
-
124
-func NewAbridgedRWC(conn wrappers.ReadWriteCloserWithAddr, connOpts *mtproto.ConnectionOpts) wrappers.ReadWriteCloserWithAddr {
125
-	return &AbridgedReadWriteCloserWithAddr{
126
-		BufferedReader: wrappers.NewBufferedReader(),
127
-		conn:           conn,
128
-		opts:           connOpts,
129
-	}
130
-}

+ 1
- 1
utils/read_current_data.go Ver fichero

@@ -2,7 +2,7 @@ package utils
2 2
 
3 3
 import "io"
4 4
 
5
-const readCurrentDataBufferSize = 1024 + 1
5
+const readCurrentDataBufferSize = 1024 + 1 // + 1 because telegram operates with blocks mod 4
6 6
 
7 7
 func ReadCurrentData(src io.Reader) (rv []byte, err error) {
8 8
 	buf := make([]byte, readCurrentDataBufferSize)

+ 14
- 0
utils/reverse_bytes.go Ver fichero

@@ -0,0 +1,14 @@
1
+package utils
2
+
3
+func ReverseBytes(data []byte) []byte {
4
+	dataLen := len(data)
5
+	rv := make([]byte, dataLen)
6
+
7
+	rv[dataLen/2] = data[dataLen/2]
8
+	for i := dataLen/2 - 1; i >= 0; i-- {
9
+		opp := dataLen - i - 1
10
+		rv[i], rv[opp] = data[opp], data[i]
11
+	}
12
+
13
+	return rv
14
+}

+ 11
- 0
utils/uint24.go Ver fichero

@@ -0,0 +1,11 @@
1
+package utils
2
+
3
+type Uint24 [3]byte
4
+
5
+func ToUint24(number uint32) Uint24 {
6
+	return Uint24{byte(number), byte(number >> 8), byte(number >> 16)}
7
+}
8
+
9
+func FromUint24(number Uint24) uint32 {
10
+	return uint32(number[0]) + (uint32(number[1]) << 8) + (uint32(number[2]) << 16)
11
+}

+ 85
- 0
wrappers/blockcipher.go Ver fichero

@@ -0,0 +1,85 @@
1
+package wrappers
2
+
3
+import (
4
+	"crypto/aes"
5
+	"crypto/cipher"
6
+	"net"
7
+
8
+	"github.com/9seconds/mtg/utils"
9
+	"github.com/juju/errors"
10
+)
11
+
12
+type WrapBlockCipher struct {
13
+	BufferedReader
14
+
15
+	conn      WrapStreamReadWriteCloser
16
+	encryptor cipher.BlockMode
17
+	decryptor cipher.BlockMode
18
+}
19
+
20
+func (w *WrapBlockCipher) Read(p []byte) (int, error) {
21
+	return w.BufferedRead(p, func() error {
22
+		var buf []byte
23
+
24
+		for len(buf) == 0 || len(buf)%aes.BlockSize != 0 {
25
+			rv, err := utils.ReadCurrentData(w.conn)
26
+			if err != nil {
27
+				return errors.Annotate(err, "Cannot read from socket")
28
+			}
29
+			buf = append(buf, rv...)
30
+		}
31
+
32
+		w.decryptor.CryptBlocks(buf, buf)
33
+		w.Buffer.Write(buf)
34
+
35
+		return nil
36
+	})
37
+}
38
+
39
+func (w *WrapBlockCipher) Write(p []byte) (int, error) {
40
+	if len(p)%aes.BlockSize > 0 {
41
+		return 0, errors.Errorf("Incorrect block size %d", len(p))
42
+	}
43
+
44
+	encrypted := make([]byte, len(p))
45
+	w.encryptor.CryptBlocks(encrypted, p)
46
+
47
+	return w.conn.Write(encrypted)
48
+}
49
+
50
+func (w *WrapBlockCipher) LogDebug(msg string, data ...interface{}) {
51
+	w.conn.LogDebug(msg, data...)
52
+}
53
+
54
+func (w *WrapBlockCipher) LogInfo(msg string, data ...interface{}) {
55
+	w.conn.LogInfo(msg, data...)
56
+}
57
+
58
+func (w *WrapBlockCipher) LogWarn(msg string, data ...interface{}) {
59
+	w.conn.LogWarn(msg, data...)
60
+}
61
+
62
+func (w *WrapBlockCipher) LogError(msg string, data ...interface{}) {
63
+	w.conn.LogError(msg, data...)
64
+}
65
+
66
+func (w *WrapBlockCipher) LocalAddr() *net.TCPAddr {
67
+	return w.conn.LocalAddr()
68
+}
69
+
70
+func (w *WrapBlockCipher) RemoteAddr() *net.TCPAddr {
71
+	return w.conn.RemoteAddr()
72
+}
73
+
74
+func (w *WrapBlockCipher) Close() error {
75
+	return w.conn.Close()
76
+}
77
+
78
+func NewWrapBlockCipher(conn WrapStreamReadWriteCloser, encryptor, decryptor cipher.BlockMode) WrapStreamReadWriteCloser {
79
+	return &WrapBlockCipher{
80
+		BufferedReader: NewBufferedReader(),
81
+		conn:           conn,
82
+		encryptor:      encryptor,
83
+		decryptor:      decryptor,
84
+	}
85
+}

+ 0
- 76
wrappers/blockcipherrwc.go Ver fichero

@@ -1,76 +0,0 @@
1
-package wrappers
2
-
3
-import (
4
-	"crypto/aes"
5
-	"crypto/cipher"
6
-	"fmt"
7
-	"net"
8
-
9
-	"github.com/juju/errors"
10
-
11
-	"github.com/9seconds/mtg/utils"
12
-)
13
-
14
-type BlockCipherReadWriteCloserWithAddr struct {
15
-	BufferedReader
16
-
17
-	conn      ReadWriteCloserWithAddr
18
-	encryptor cipher.BlockMode
19
-	decryptor cipher.BlockMode
20
-}
21
-
22
-func (c *BlockCipherReadWriteCloserWithAddr) Read(p []byte) (int, error) {
23
-	return c.BufferedRead(p, func() error {
24
-		var buf []byte
25
-
26
-		for len(buf) == 0 || len(buf)%aes.BlockSize != 0 {
27
-			rv, err := utils.ReadCurrentData(c.conn)
28
-			if err != nil {
29
-				return errors.Annotate(err, "Cannot read from socket")
30
-			}
31
-			buf = append(buf, rv...)
32
-		}
33
-
34
-		c.decryptor.CryptBlocks(buf, buf)
35
-		c.Buffer.Write(buf)
36
-
37
-		return nil
38
-	})
39
-}
40
-
41
-func (c *BlockCipherReadWriteCloserWithAddr) Write(p []byte) (int, error) {
42
-	if len(p)%aes.BlockSize > 0 {
43
-		return 0, errors.Errorf("Incorrect block size %d", len(p))
44
-	}
45
-
46
-	encrypted := make([]byte, len(p))
47
-	c.encryptor.CryptBlocks(encrypted, p)
48
-
49
-	return c.conn.Write(encrypted)
50
-}
51
-
52
-func (c *BlockCipherReadWriteCloserWithAddr) Close() error {
53
-	fmt.Println("BlockCipherReadWriteCloserWithAddr closes", "sockid", c.SocketID(), "bufsize", c.Buffer.Len())
54
-	return c.conn.Close()
55
-}
56
-
57
-func (c *BlockCipherReadWriteCloserWithAddr) LocalAddr() *net.TCPAddr {
58
-	return c.conn.LocalAddr()
59
-}
60
-
61
-func (c *BlockCipherReadWriteCloserWithAddr) RemoteAddr() *net.TCPAddr {
62
-	return c.conn.RemoteAddr()
63
-}
64
-
65
-func (c *BlockCipherReadWriteCloserWithAddr) SocketID() string {
66
-	return c.conn.SocketID()
67
-}
68
-
69
-func NewBlockCipherRWC(conn ReadWriteCloserWithAddr, encryptor, decryptor cipher.BlockMode) ReadWriteCloserWithAddr {
70
-	return &BlockCipherReadWriteCloserWithAddr{
71
-		BufferedReader: NewBufferedReader(),
72
-		conn:           conn,
73
-		encryptor:      encryptor,
74
-		decryptor:      decryptor,
75
-	}
76
-}

+ 0
- 27
wrappers/buffer_pool.go Ver fichero

@@ -1,27 +0,0 @@
1
-package wrappers
2
-
3
-import (
4
-	"bytes"
5
-	"sync"
6
-)
7
-
8
-var bufPool sync.Pool
9
-
10
-func getBuffer() *bytes.Buffer {
11
-	buf := bufPool.Get().(*bytes.Buffer)
12
-	buf.Reset()
13
-
14
-	return buf
15
-}
16
-
17
-func putBuffer(buf *bytes.Buffer) {
18
-	bufPool.Put(buf)
19
-}
20
-
21
-func init() {
22
-	bufPool = sync.Pool{
23
-		New: func() interface{} {
24
-			return &bytes.Buffer{}
25
-		},
26
-	}
27
-}

+ 1
- 9
wrappers/buffered_reader.go Ver fichero

@@ -1,19 +1,11 @@
1 1
 package wrappers
2 2
 
3
-import (
4
-	"bytes"
5
-
6
-	"github.com/juju/errors"
7
-)
3
+import "bytes"
8 4
 
9 5
 type BufferedReader struct {
10 6
 	Buffer *bytes.Buffer
11 7
 }
12 8
 
13
-var (
14
-	BufferedReaderContinue = errors.New("Please continue reading")
15
-)
16
-
17 9
 func (b *BufferedReader) BufferedRead(p []byte, callback func() error) (int, error) {
18 10
 	if b.Buffer.Len() > 0 {
19 11
 		return b.flush(p)

+ 115
- 0
wrappers/conn.go Ver fichero

@@ -0,0 +1,115 @@
1
+package wrappers
2
+
3
+import (
4
+	"net"
5
+	"time"
6
+
7
+	"go.uber.org/zap"
8
+)
9
+
10
+type ConnPurpose uint8
11
+
12
+func (c ConnPurpose) String() string {
13
+	switch c {
14
+	case ConnPurposeClient:
15
+		return "client"
16
+	case ConnPurposeTelegram:
17
+		return "telegram"
18
+	}
19
+
20
+	return ""
21
+}
22
+
23
+const (
24
+	ConnPurposeClient = iota
25
+	ConnPurposeTelegram
26
+)
27
+
28
+const (
29
+	connTimeoutRead  = 5 * time.Minute
30
+	connTimeoutWrite = 5 * time.Minute
31
+)
32
+
33
+type WrapConn struct {
34
+	purpose    ConnPurpose
35
+	connID     string
36
+	conn       net.Conn
37
+	logger     *zap.SugaredLogger
38
+	publicIPv4 net.IP
39
+	publicIPv6 net.IP
40
+}
41
+
42
+func (w *WrapConn) Write(p []byte) (int, error) {
43
+	w.conn.SetWriteDeadline(time.Now().Add(connTimeoutWrite))
44
+	n, err := w.conn.Write(p)
45
+
46
+	w.logger.Debugw("Write to stream", "bytes", n, "error", err)
47
+
48
+	return n, err
49
+}
50
+
51
+func (w *WrapConn) Read(p []byte) (int, error) {
52
+	w.conn.SetReadDeadline(time.Now().Add(connTimeoutRead))
53
+	n, err := w.conn.Read(p)
54
+
55
+	w.logger.Debugw("Read from stream", "bytes", n, "error", err)
56
+
57
+	return n, err
58
+}
59
+
60
+func (w *WrapConn) Close() error {
61
+	defer w.LogDebug("Closed connection")
62
+	return w.conn.Close()
63
+}
64
+
65
+func (w *WrapConn) LocalAddr() *net.TCPAddr {
66
+	addr := w.conn.LocalAddr().(*net.TCPAddr)
67
+	newAddr := *addr
68
+
69
+	if w.RemoteAddr().IP.To4() != nil {
70
+		if w.publicIPv4 != nil {
71
+			newAddr.IP = w.publicIPv4
72
+		}
73
+	} else if w.publicIPv6 != nil {
74
+		newAddr.IP = w.publicIPv6
75
+	}
76
+
77
+	return &newAddr
78
+}
79
+
80
+func (w *WrapConn) RemoteAddr() *net.TCPAddr {
81
+	return w.conn.RemoteAddr().(*net.TCPAddr)
82
+}
83
+
84
+func (w *WrapConn) LogDebug(msg string, data ...interface{}) {
85
+	w.logger.Debugw(msg, data...)
86
+}
87
+
88
+func (w *WrapConn) LogInfo(msg string, data ...interface{}) {
89
+	w.logger.Infow(msg, data...)
90
+}
91
+
92
+func (w *WrapConn) LogWarn(msg string, data ...interface{}) {
93
+	w.logger.Warnw(msg, data...)
94
+}
95
+
96
+func (w *WrapConn) LogError(msg string, data ...interface{}) {
97
+	w.logger.Errorw(msg, data...)
98
+}
99
+
100
+func NewConn(connID string, purpose ConnPurpose, conn net.Conn, publicIPv4, publicIPv6 net.IP) WrapStreamReadWriteCloser {
101
+	logger := zap.S().With(
102
+		"connection_id", connID,
103
+		"local_address", conn.LocalAddr(),
104
+		"remote_address", conn.RemoteAddr(),
105
+	)
106
+
107
+	return &WrapConn{
108
+		logger:     logger,
109
+		purpose:    purpose,
110
+		connID:     connID,
111
+		conn:       conn,
112
+		publicIPv4: publicIPv4,
113
+		publicIPv6: publicIPv6,
114
+	}
115
+}

+ 76
- 0
wrappers/ctx.go Ver fichero

@@ -0,0 +1,76 @@
1
+package wrappers
2
+
3
+import (
4
+	"context"
5
+	"net"
6
+
7
+	"github.com/juju/errors"
8
+)
9
+
10
+type WrapCtx struct {
11
+	cancel context.CancelFunc
12
+	conn   WrapStreamReadWriteCloser
13
+	ctx    context.Context
14
+}
15
+
16
+func (w *WrapCtx) Read(p []byte) (int, error) {
17
+	select {
18
+	case <-w.ctx.Done():
19
+		return 0, errors.Annotate(w.ctx.Err(), "Read is failed because of closed context")
20
+	default:
21
+		n, err := w.conn.Read(p)
22
+		if err != nil {
23
+			w.cancel()
24
+		}
25
+		return n, err
26
+	}
27
+}
28
+
29
+func (w *WrapCtx) Write(p []byte) (int, error) {
30
+	select {
31
+	case <-w.ctx.Done():
32
+		return 0, errors.Annotate(w.ctx.Err(), "Write is failed because of closed context")
33
+	default:
34
+		n, err := w.conn.Write(p)
35
+		if err != nil {
36
+			w.cancel()
37
+		}
38
+		return n, err
39
+	}
40
+}
41
+
42
+func (w *WrapCtx) LogDebug(msg string, data ...interface{}) {
43
+	w.conn.LogDebug(msg, data...)
44
+}
45
+
46
+func (w *WrapCtx) LogInfo(msg string, data ...interface{}) {
47
+	w.conn.LogInfo(msg, data...)
48
+}
49
+
50
+func (w *WrapCtx) LogWarn(msg string, data ...interface{}) {
51
+	w.conn.LogWarn(msg, data...)
52
+}
53
+
54
+func (w *WrapCtx) LogError(msg string, data ...interface{}) {
55
+	w.conn.LogError(msg, data...)
56
+}
57
+
58
+func (w *WrapCtx) LocalAddr() *net.TCPAddr {
59
+	return w.conn.LocalAddr()
60
+}
61
+
62
+func (w *WrapCtx) RemoteAddr() *net.TCPAddr {
63
+	return w.conn.RemoteAddr()
64
+}
65
+
66
+func (w *WrapCtx) Close() error {
67
+	return w.conn.Close()
68
+}
69
+
70
+func NewCtx(ctx context.Context, cancel context.CancelFunc, conn WrapStreamReadWriteCloser) WrapStreamReadWriteCloser {
71
+	return &WrapCtx{
72
+		ctx:    ctx,
73
+		cancel: cancel,
74
+		conn:   conn,
75
+	}
76
+}

+ 0
- 71
wrappers/ctxrwc.go Ver fichero

@@ -1,71 +0,0 @@
1
-package wrappers
2
-
3
-import (
4
-	"context"
5
-	"net"
6
-
7
-	"github.com/juju/errors"
8
-)
9
-
10
-// CtxReadWriteCloser wraps underlying connection and does management of the
11
-// context and its cancel function.
12
-type CtxReadWriteCloserWithAddr struct {
13
-	ctx    context.Context
14
-	conn   ReadWriteCloserWithAddr
15
-	cancel context.CancelFunc
16
-}
17
-
18
-// Read reads from connection
19
-func (c *CtxReadWriteCloserWithAddr) Read(p []byte) (int, error) {
20
-	select {
21
-	case <-c.ctx.Done():
22
-		return 0, errors.Annotate(c.ctx.Err(), "Read is failed because of closed context")
23
-	default:
24
-		n, err := c.conn.Read(p)
25
-		if err != nil {
26
-			c.cancel()
27
-		}
28
-		return n, err
29
-	}
30
-}
31
-
32
-// Write writes into connection.
33
-func (c *CtxReadWriteCloserWithAddr) Write(p []byte) (int, error) {
34
-	select {
35
-	case <-c.ctx.Done():
36
-		return 0, errors.Annotate(c.ctx.Err(), "Write is failed because of closed context")
37
-	default:
38
-		n, err := c.conn.Write(p)
39
-		if err != nil {
40
-			c.cancel()
41
-		}
42
-		return n, err
43
-	}
44
-}
45
-
46
-// Close closes underlying connection.
47
-func (c *CtxReadWriteCloserWithAddr) Close() error {
48
-	return c.conn.Close()
49
-}
50
-
51
-func (c *CtxReadWriteCloserWithAddr) LocalAddr() *net.TCPAddr {
52
-	return c.conn.LocalAddr()
53
-}
54
-
55
-func (c *CtxReadWriteCloserWithAddr) RemoteAddr() *net.TCPAddr {
56
-	return c.conn.RemoteAddr()
57
-}
58
-
59
-func (c *CtxReadWriteCloserWithAddr) SocketID() string {
60
-	return c.conn.SocketID()
61
-}
62
-
63
-// NewCtxRWC returns ReadWriteCloser which respects given context,
64
-// cancellation etc.
65
-func NewCtxRWC(ctx context.Context, cancel context.CancelFunc, conn ReadWriteCloserWithAddr) ReadWriteCloserWithAddr {
66
-	return &CtxReadWriteCloserWithAddr{
67
-		conn:   conn,
68
-		ctx:    ctx,
69
-		cancel: cancel,
70
-	}
71
-}

+ 0
- 59
wrappers/logrwc.go Ver fichero

@@ -1,59 +0,0 @@
1
-package wrappers
2
-
3
-import (
4
-	"net"
5
-
6
-	"go.uber.org/zap"
7
-)
8
-
9
-// LogReadWriteCloser adds additional logging for reading/writing. All
10
-// logging is performed for debug mode only.
11
-type LogReadWriteCloserWithAddr struct {
12
-	conn   ReadWriteCloserWithAddr
13
-	logger *zap.SugaredLogger
14
-	sockid string
15
-	name   string
16
-}
17
-
18
-// Read reads from connection
19
-func (l *LogReadWriteCloserWithAddr) Read(p []byte) (n int, err error) {
20
-	n, err = l.conn.Read(p)
21
-	l.logger.Debugw("Finish reading", "name", l.name, "socketid", l.sockid, "nbytes", n, "error", err, "localAddr", l.LocalAddr())
22
-	return
23
-}
24
-
25
-// Write writes into connection.
26
-func (l *LogReadWriteCloserWithAddr) Write(p []byte) (n int, err error) {
27
-	n, err = l.conn.Write(p)
28
-	l.logger.Debugw("Finish writing", "name", l.name, "socketid", l.sockid, "nbytes", n, "error", err, "localAddr", l.LocalAddr())
29
-	return
30
-}
31
-
32
-// Close closes underlying connection.
33
-func (l *LogReadWriteCloserWithAddr) Close() error {
34
-	err := l.conn.Close()
35
-	l.logger.Debugw("Finish closing socket", "name", l.name, "socketid", l.sockid, "error", err)
36
-	return err
37
-}
38
-
39
-func (l *LogReadWriteCloserWithAddr) LocalAddr() *net.TCPAddr {
40
-	return l.conn.LocalAddr()
41
-}
42
-
43
-func (l *LogReadWriteCloserWithAddr) RemoteAddr() *net.TCPAddr {
44
-	return l.conn.RemoteAddr()
45
-}
46
-
47
-func (l *LogReadWriteCloserWithAddr) SocketID() string {
48
-	return l.sockid
49
-}
50
-
51
-// NewLogRWC wraps ReadWriteCloser with logger calls.
52
-func NewLogRWC(conn ReadWriteCloserWithAddr, logger *zap.SugaredLogger, sockid string, name string) ReadWriteCloserWithAddr {
53
-	return &LogReadWriteCloserWithAddr{
54
-		conn:   conn,
55
-		logger: logger,
56
-		sockid: sockid,
57
-		name:   name,
58
-	}
59
-}

+ 155
- 0
wrappers/mtproto_abridged.go Ver fichero

@@ -0,0 +1,155 @@
1
+package wrappers
2
+
3
+import (
4
+	"bytes"
5
+	"io"
6
+	"net"
7
+
8
+	"github.com/juju/errors"
9
+
10
+	"github.com/9seconds/mtg/mtproto"
11
+	"github.com/9seconds/mtg/utils"
12
+)
13
+
14
+const (
15
+	abridgedSmallPacketLength = 0x7f
16
+	abridgedQuickAckLength    = 0x80
17
+	abridgedLargePacketLength = 16777216 // 256 ^ 3
18
+)
19
+
20
+type MTProtoAbridged struct {
21
+	conn WrapStreamReadWriteCloser
22
+	opts *mtproto.ConnectionOpts
23
+
24
+	readCounter  uint32
25
+	writeCounter uint32
26
+}
27
+
28
+func (m *MTProtoAbridged) Read() ([]byte, error) {
29
+	m.LogDebug("Read abridged packet",
30
+		"simple_ack", m.opts.WriteHacks.SimpleAck,
31
+		"quick_ack", m.opts.WriteHacks.QuickAck,
32
+		"counter", m.readCounter,
33
+	)
34
+
35
+	buf := &bytes.Buffer{}
36
+	buf.Grow(3)
37
+
38
+	if _, err := io.CopyN(buf, m.conn, 1); err != nil {
39
+		return nil, errors.Annotate(err, "Cannot read message length")
40
+	}
41
+	msgLength := uint8(buf.Bytes()[0])
42
+	buf.Reset()
43
+
44
+	m.LogDebug("Abridged packet first byte",
45
+		"byte", msgLength,
46
+		"counter", m.readCounter,
47
+	)
48
+
49
+	if msgLength >= abridgedQuickAckLength {
50
+		m.opts.ReadHacks.QuickAck = true
51
+		msgLength -= abridgedQuickAckLength
52
+	}
53
+
54
+	msgLength32 := uint32(msgLength)
55
+	if msgLength == abridgedSmallPacketLength {
56
+		if _, err := io.CopyN(buf, m.conn, 3); err != nil {
57
+			return nil, errors.Annotate(err, "Cannot read the correct message length")
58
+		}
59
+		number := utils.Uint24{}
60
+		copy(number[:], buf.Bytes())
61
+		msgLength32 = utils.FromUint24(number)
62
+	}
63
+	msgLength32 *= 4
64
+
65
+	m.LogDebug("Abridged packet length",
66
+		"length", msgLength32,
67
+		"counter", m.readCounter,
68
+	)
69
+
70
+	buf.Reset()
71
+	buf.Grow(int(msgLength32))
72
+
73
+	if _, err := io.CopyN(buf, m.conn, int64(msgLength32)); err != nil {
74
+		return nil, errors.Annotate(err, "Cannot read message")
75
+	}
76
+	m.readCounter++
77
+
78
+	return buf.Bytes(), nil
79
+}
80
+
81
+func (m *MTProtoAbridged) Write(p []byte) (int, error) {
82
+	m.LogDebug("Write abridged packet",
83
+		"length", len(p),
84
+		"simple_ack", m.opts.WriteHacks.SimpleAck,
85
+		"quick_ack", m.opts.WriteHacks.QuickAck,
86
+		"counter", m.writeCounter,
87
+	)
88
+
89
+	if len(p)%4 == 0 {
90
+		return 0, errors.Errorf("Incorrect packet length %d", len(p))
91
+	}
92
+
93
+	if m.opts.WriteHacks.SimpleAck {
94
+		return m.conn.Write(utils.ReverseBytes(p))
95
+	}
96
+
97
+	packetLength := len(p) / 4
98
+	switch {
99
+	case packetLength < abridgedSmallPacketLength:
100
+		newData := append([]byte{byte(packetLength)}, p...)
101
+
102
+		m.writeCounter++
103
+		return m.conn.Write(newData)
104
+
105
+	case packetLength < abridgedLargePacketLength:
106
+		length24 := utils.ToUint24(uint32(packetLength))
107
+
108
+		buf := &bytes.Buffer{}
109
+		buf.Grow(1 + 3 + len(p))
110
+
111
+		buf.WriteByte(byte(abridgedSmallPacketLength))
112
+		buf.Write(length24[:])
113
+		buf.Write(p)
114
+
115
+		m.writeCounter++
116
+		return m.conn.Write(buf.Bytes())
117
+	}
118
+
119
+	return 0, errors.Errorf("Packet is too big %d", len(p))
120
+}
121
+
122
+func (m *MTProtoAbridged) LogDebug(msg string, data ...interface{}) {
123
+	m.conn.LogDebug(msg, data...)
124
+}
125
+
126
+func (m *MTProtoAbridged) LogInfo(msg string, data ...interface{}) {
127
+	m.conn.LogInfo(msg, data...)
128
+}
129
+
130
+func (m *MTProtoAbridged) LogWarn(msg string, data ...interface{}) {
131
+	m.conn.LogWarn(msg, data...)
132
+}
133
+
134
+func (m *MTProtoAbridged) LogError(msg string, data ...interface{}) {
135
+	m.conn.LogError(msg, data...)
136
+}
137
+
138
+func (m *MTProtoAbridged) LocalAddr() *net.TCPAddr {
139
+	return m.conn.LocalAddr()
140
+}
141
+
142
+func (m *MTProtoAbridged) RemoteAddr() *net.TCPAddr {
143
+	return m.conn.RemoteAddr()
144
+}
145
+
146
+func (m *MTProtoAbridged) Close() error {
147
+	return m.conn.Close()
148
+}
149
+
150
+func NewMTProtoAbridged(conn WrapStreamReadWriteCloser, opts *mtproto.ConnectionOpts) WrapPacketReadWriteCloser {
151
+	return &MTProtoAbridged{
152
+		conn: conn,
153
+		opts: opts,
154
+	}
155
+}

+ 0
- 14
wrappers/rwcaddr.go Ver fichero

@@ -1,14 +0,0 @@
1
-package wrappers
2
-
3
-import (
4
-	"io"
5
-	"net"
6
-)
7
-
8
-type ReadWriteCloserWithAddr interface {
9
-	io.ReadWriteCloser
10
-
11
-	LocalAddr() *net.TCPAddr
12
-	RemoteAddr() *net.TCPAddr
13
-	SocketID() string
14
-}

+ 67
- 0
wrappers/streamcipher.go Ver fichero

@@ -0,0 +1,67 @@
1
+package wrappers
2
+
3
+import (
4
+	"crypto/cipher"
5
+	"net"
6
+
7
+	"github.com/juju/errors"
8
+)
9
+
10
+type WrapStreamCipher struct {
11
+	encryptor cipher.Stream
12
+	decryptor cipher.Stream
13
+	conn      WrapStreamReadWriteCloser
14
+}
15
+
16
+func (w *WrapStreamCipher) Read(p []byte) (int, error) {
17
+	n, err := w.conn.Read(p)
18
+	if err != nil {
19
+		return 0, errors.Annotate(err, "Cannot read stream ciphered data")
20
+	}
21
+	w.decryptor.XORKeyStream(p, p[:n])
22
+
23
+	return n, nil
24
+}
25
+
26
+func (w *WrapStreamCipher) Write(p []byte) (int, error) {
27
+	encrypted := make([]byte, len(p))
28
+	w.encryptor.XORKeyStream(encrypted, p)
29
+
30
+	return w.conn.Write(encrypted)
31
+}
32
+
33
+func (w *WrapStreamCipher) LogDebug(msg string, data ...interface{}) {
34
+	w.conn.LogDebug(msg, data...)
35
+}
36
+
37
+func (w *WrapStreamCipher) LogInfo(msg string, data ...interface{}) {
38
+	w.conn.LogInfo(msg, data...)
39
+}
40
+
41
+func (w *WrapStreamCipher) LogWarn(msg string, data ...interface{}) {
42
+	w.conn.LogWarn(msg, data...)
43
+}
44
+
45
+func (w *WrapStreamCipher) LogError(msg string, data ...interface{}) {
46
+	w.conn.LogError(msg, data...)
47
+}
48
+
49
+func (w *WrapStreamCipher) LocalAddr() *net.TCPAddr {
50
+	return w.conn.LocalAddr()
51
+}
52
+
53
+func (w *WrapStreamCipher) RemoteAddr() *net.TCPAddr {
54
+	return w.conn.RemoteAddr()
55
+}
56
+
57
+func (w *WrapStreamCipher) Close() error {
58
+	return w.conn.Close()
59
+}
60
+
61
+func NewStreamCipher(conn WrapStreamReadWriteCloser, encryptor, decryptor cipher.Stream) WrapStreamReadWriteCloser {
62
+	return &WrapStreamCipher{
63
+		conn:      conn,
64
+		encryptor: encryptor,
65
+		decryptor: decryptor,
66
+	}
67
+}

+ 0
- 66
wrappers/streamcipherrwc.go Ver fichero

@@ -1,66 +0,0 @@
1
-package wrappers
2
-
3
-import (
4
-	"crypto/cipher"
5
-	"net"
6
-)
7
-
8
-// StreamCipherReadWriteCloser is a ReadWriteCloser which ciphers
9
-// incoming and outgoing data with givem cipher.Stream instances.
10
-type StreamCipherReadWriteCloserWithAddr struct {
11
-	encryptor cipher.Stream
12
-	decryptor cipher.Stream
13
-	conn      ReadWriteCloserWithAddr
14
-}
15
-
16
-// Read reads from connection
17
-func (c *StreamCipherReadWriteCloserWithAddr) Read(p []byte) (n int, err error) {
18
-	n, err = c.conn.Read(p)
19
-	c.decryptor.XORKeyStream(p, p[:n])
20
-	return
21
-}
22
-
23
-// Write writes into connection.
24
-func (c *StreamCipherReadWriteCloserWithAddr) Write(p []byte) (int, error) {
25
-	// This is to decrease an amount of allocations. Unfortunately, escape
26
-	// analysis in (at least Golang 1.10) is absolutely not perfect. For
27
-	// example, it understands that we want to have a slice locally, right?
28
-	// But since slice is effectively 2 ints + uintptr to [number]byte, the
29
-	// most heavyweight part is placed in heap.
30
-	buf := getBuffer()
31
-	defer putBuffer(buf)
32
-	buf.Grow(len(p))
33
-	buf.Write(p)
34
-
35
-	encrypted := buf.Bytes()
36
-	c.encryptor.XORKeyStream(encrypted, p)
37
-
38
-	return c.conn.Write(encrypted)
39
-}
40
-
41
-// Close closes underlying connection.
42
-func (c *StreamCipherReadWriteCloserWithAddr) Close() error {
43
-	return c.conn.Close()
44
-}
45
-
46
-func (c *StreamCipherReadWriteCloserWithAddr) LocalAddr() *net.TCPAddr {
47
-	return c.conn.LocalAddr()
48
-}
49
-
50
-func (c *StreamCipherReadWriteCloserWithAddr) RemoteAddr() *net.TCPAddr {
51
-	return c.conn.RemoteAddr()
52
-}
53
-
54
-func (c *StreamCipherReadWriteCloserWithAddr) SocketID() string {
55
-	return c.conn.SocketID()
56
-}
57
-
58
-// NewStreamCipherRWC returns wrapper which transparently
59
-// encrypts/decrypts traffic with obfuscated2 protocol.
60
-func NewStreamCipherRWC(conn ReadWriteCloserWithAddr, encryptor, decryptor cipher.Stream) ReadWriteCloserWithAddr {
61
-	return &StreamCipherReadWriteCloserWithAddr{
62
-		conn:      conn,
63
-		encryptor: encryptor,
64
-		decryptor: decryptor,
65
-	}
66
-}

+ 0
- 61
wrappers/timeoutrwc.go Ver fichero

@@ -1,61 +0,0 @@
1
-package wrappers
2
-
3
-import (
4
-	"net"
5
-	"time"
6
-
7
-	"github.com/9seconds/mtg/config"
8
-)
9
-
10
-type TimeoutReadWriteCloserWithAddr struct {
11
-	conn       net.Conn
12
-	sock       string
13
-	publicIPv4 net.IP
14
-	publicIPv6 net.IP
15
-}
16
-
17
-func (t *TimeoutReadWriteCloserWithAddr) Read(p []byte) (int, error) {
18
-	t.conn.SetReadDeadline(time.Now().Add(config.TimeoutRead))
19
-	return t.conn.Read(p)
20
-}
21
-
22
-func (t *TimeoutReadWriteCloserWithAddr) Write(p []byte) (int, error) {
23
-	t.conn.SetWriteDeadline(time.Now().Add(config.TimeoutWrite))
24
-	return t.conn.Write(p)
25
-}
26
-
27
-func (t *TimeoutReadWriteCloserWithAddr) Close() error {
28
-	return t.conn.Close()
29
-}
30
-
31
-func (t *TimeoutReadWriteCloserWithAddr) RemoteAddr() *net.TCPAddr {
32
-	return t.conn.RemoteAddr().(*net.TCPAddr)
33
-}
34
-
35
-func (t *TimeoutReadWriteCloserWithAddr) LocalAddr() *net.TCPAddr {
36
-	addr := t.conn.LocalAddr().(*net.TCPAddr)
37
-	newAddr := *addr
38
-
39
-	if t.RemoteAddr().IP.To4() != nil {
40
-		if t.publicIPv4 != nil {
41
-			newAddr.IP = t.publicIPv4
42
-		}
43
-	} else if t.publicIPv6 != nil {
44
-		newAddr.IP = t.publicIPv6
45
-	}
46
-
47
-	return &newAddr
48
-}
49
-
50
-func (t *TimeoutReadWriteCloserWithAddr) SocketID() string {
51
-	return t.sock
52
-}
53
-
54
-func NewTimeoutRWC(conn net.Conn, sock string, ipv4, ipv6 net.IP) ReadWriteCloserWithAddr {
55
-	return &TimeoutReadWriteCloserWithAddr{
56
-		conn:       conn,
57
-		publicIPv4: ipv4,
58
-		publicIPv6: ipv6,
59
-		sock:       sock,
60
-	}
61
-}

+ 0
- 51
wrappers/trafficrwc.go Ver fichero

@@ -1,51 +0,0 @@
1
-package wrappers
2
-
3
-import "net"
4
-
5
-// TrafficReadWriteCloser counts an amount of ingress/egress traffic by
6
-// calling given callbacks.
7
-type TrafficReadWriteCloserWithAddr struct {
8
-	conn          ReadWriteCloserWithAddr
9
-	readCallback  func(int)
10
-	writeCallback func(int)
11
-}
12
-
13
-// Read reads from connection
14
-func (t *TrafficReadWriteCloserWithAddr) Read(p []byte) (n int, err error) {
15
-	n, err = t.conn.Read(p)
16
-	t.readCallback(n)
17
-	return
18
-}
19
-
20
-// Write writes into connection.
21
-func (t *TrafficReadWriteCloserWithAddr) Write(p []byte) (n int, err error) {
22
-	n, err = t.conn.Write(p)
23
-	t.writeCallback(n)
24
-	return
25
-}
26
-
27
-// Close closes underlying connection.
28
-func (t *TrafficReadWriteCloserWithAddr) Close() error {
29
-	return t.conn.Close()
30
-}
31
-
32
-func (t *TrafficReadWriteCloserWithAddr) LocalAddr() *net.TCPAddr {
33
-	return t.conn.LocalAddr()
34
-}
35
-
36
-func (t *TrafficReadWriteCloserWithAddr) RemoteAddr() *net.TCPAddr {
37
-	return t.conn.RemoteAddr()
38
-}
39
-
40
-func (t *TrafficReadWriteCloserWithAddr) SocketID() string {
41
-	return t.conn.SocketID()
42
-}
43
-
44
-// NewTrafficRWC wraps ReadWriteCloser to have read/write callbacks.
45
-func NewTrafficRWC(conn ReadWriteCloserWithAddr, readCallback, writeCallback func(int)) ReadWriteCloserWithAddr {
46
-	return &TrafficReadWriteCloserWithAddr{
47
-		conn:          conn,
48
-		readCallback:  readCallback,
49
-		writeCallback: writeCallback,
50
-	}
51
-}

+ 78
- 0
wrappers/wrap.go Ver fichero

@@ -0,0 +1,78 @@
1
+package wrappers
2
+
3
+import (
4
+	"io"
5
+	"net"
6
+)
7
+
8
+type Wrap interface {
9
+	LogDebug(msg string, data ...interface{})
10
+	LogInfo(msg string, data ...interface{})
11
+	LogWarn(msg string, data ...interface{})
12
+	LogError(msg string, data ...interface{})
13
+
14
+	LocalAddr() *net.TCPAddr
15
+	RemoteAddr() *net.TCPAddr
16
+}
17
+
18
+type WrapWriter interface {
19
+	io.Writer
20
+	Wrap
21
+}
22
+
23
+type WrapWriteCloser interface {
24
+	io.Closer
25
+	WrapWriter
26
+}
27
+
28
+type WrapStreamReader interface {
29
+	io.Reader
30
+	Wrap
31
+}
32
+
33
+type WrapStreamReadCloser interface {
34
+	io.Closer
35
+	WrapStreamReader
36
+}
37
+
38
+type WrapStreamReadWriter interface {
39
+	io.Writer
40
+	WrapStreamReader
41
+}
42
+
43
+type WrapStreamWriteCloser interface {
44
+	io.Closer
45
+	io.Writer
46
+	Wrap
47
+}
48
+
49
+type WrapStreamReadWriteCloser interface {
50
+	io.Closer
51
+	WrapStreamReadWriter
52
+}
53
+
54
+type WrapPacketReader interface {
55
+	Read() ([]byte, error)
56
+	Wrap
57
+}
58
+
59
+type WrapPacketReadWriter interface {
60
+	io.Writer
61
+	WrapPacketReader
62
+}
63
+
64
+type WrapBlockReadCloser interface {
65
+	io.Closer
66
+	WrapPacketReader
67
+}
68
+
69
+type WrapPacketWriteCloser interface {
70
+	io.Writer
71
+	io.Closer
72
+	Wrap
73
+}
74
+
75
+type WrapPacketReadWriteCloser interface {
76
+	io.Closer
77
+	WrapPacketReadWriter
78
+}

Loading…
Cancelar
Guardar