Browse Source

wip

tags/1.0^2
9seconds 6 years ago
parent
commit
c9743b5675

+ 3
- 0
conntypes/packet.go View File

@@ -0,0 +1,3 @@
1
+package conntypes
2
+
3
+type Packet []byte

wrappers/interfaces.go → conntypes/wrappers.go View File

@@ -1,4 +1,4 @@
1
-package wrappers
1
+package conntypes
2 2
 
3 3
 import (
4 4
 	"io"
@@ -8,8 +8,6 @@ import (
8 8
 	"go.uber.org/zap"
9 9
 )
10 10
 
11
-type Packet []byte
12
-
13 11
 // Wrap is a base interface for all wrappers in this package.
14 12
 type Wrap interface {
15 13
 	Conn() net.Conn

+ 2
- 0
go.mod View File

@@ -9,6 +9,8 @@ require (
9 9
 	github.com/allegro/bigcache v1.2.1
10 10
 	github.com/beevik/ntp v0.2.0
11 11
 	github.com/cespare/xxhash v1.1.0
12
+	github.com/dustin/go-humanize v1.0.0
13
+	github.com/gammazero/deque v0.0.0-20190521012701-46e4ffb7a622
12 14
 	github.com/juju/errors v0.0.0-20190806202954-0232dcc7464d
13 15
 	github.com/kr/pretty v0.1.0 // indirect
14 16
 	github.com/pkg/errors v0.8.1

+ 4
- 0
go.sum View File

@@ -24,6 +24,10 @@ github.com/cespare/xxhash v1.1.0/go.mod h1:XrSqR1VqqWfGrhpAt58auRo0WTKS1nRRg3ghf
24 24
 github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
25 25
 github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
26 26
 github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
27
+github.com/dustin/go-humanize v1.0.0 h1:VSnTsYCnlFHaM2/igO1h6X3HA71jcobQuxemgkq4zYo=
28
+github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk=
29
+github.com/gammazero/deque v0.0.0-20190521012701-46e4ffb7a622 h1:lxbhOGZ9pU3Kf8P6lFluUcE82yVZn2EqEf4+mWRNPV0=
30
+github.com/gammazero/deque v0.0.0-20190521012701-46e4ffb7a622/go.mod h1:D90+MBHVc9Sk1lJAbEVgws0eYEurY4mv2TDso3Nxh3w=
27 31
 github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
28 32
 github.com/go-logfmt/logfmt v0.3.0/go.mod h1:Qt1PoO58o5twSAckw1HlFXLmHsOX5/0LbT9GBnD5lWE=
29 33
 github.com/go-logfmt/logfmt v0.4.0/go.mod h1:3RMwSq7FuexP4Kalkev3ejPJsZTpXXBr9+V4qmtdjCk=

+ 59
- 0
hub/closeable_channel.go View File

@@ -0,0 +1,59 @@
1
+package hub
2
+
3
+import (
4
+	"context"
5
+	"errors"
6
+	"time"
7
+
8
+	"github.com/9seconds/mtg/conntypes"
9
+)
10
+
11
+const closeableChannelReadTimeout = 2 * time.Minute
12
+
13
+type ChannelReadCloser interface {
14
+	Read() (conntypes.Packet, error)
15
+	Close()
16
+}
17
+
18
+type closeableChannel struct {
19
+	channel chan conntypes.Packet
20
+	ctx     context.Context
21
+	cancel  context.CancelFunc
22
+}
23
+
24
+func (c *closeableChannel) Read() (conntypes.Packet, error) {
25
+	timer := time.NewTimer(closeableChannelReadTimeout)
26
+	defer timer.Stop()
27
+
28
+	select {
29
+	case <-timer.C:
30
+		return nil, errors.New("timeout")
31
+	case <-c.ctx.Done():
32
+		return nil, errors.New("channel was closed")
33
+	case packet := <-c.channel:
34
+		return packet, nil
35
+	}
36
+}
37
+
38
+func (c *closeableChannel) write(packet conntypes.Packet) error {
39
+	select {
40
+	case <-c.ctx.Done():
41
+		return errors.New("channel was closed")
42
+	case c.channel <- packet:
43
+		return nil
44
+	}
45
+}
46
+
47
+func (c *closeableChannel) Close() {
48
+	c.cancel()
49
+	c.channel = nil
50
+}
51
+
52
+func newCloseableChannel(ctx context.Context) *closeableChannel {
53
+	ctx, cancel := context.WithCancel(ctx)
54
+	return &closeableChannel{
55
+		channel: make(chan conntypes.Packet),
56
+		ctx:     ctx,
57
+		cancel:  cancel,
58
+	}
59
+}

+ 100
- 0
hub/connection.go View File

@@ -0,0 +1,100 @@
1
+package hub
2
+
3
+import (
4
+	"fmt"
5
+	"math/rand"
6
+	"sync"
7
+
8
+	"github.com/9seconds/mtg/conntypes"
9
+	"github.com/9seconds/mtg/mtproto"
10
+	"github.com/9seconds/mtg/protocol"
11
+)
12
+
13
+type connectionID int
14
+
15
+type connection struct {
16
+	conn    conntypes.PacketReadWriteCloser
17
+	mutex   sync.RWMutex
18
+	id      connectionID
19
+	hub     *connectionHub
20
+	pending uint
21
+	closing bool
22
+}
23
+
24
+func (c *connection) Write(packet conntypes.Packet) error {
25
+	c.mutex.Lock()
26
+	defer c.mutex.Unlock()
27
+
28
+	err := c.conn.Write(packet)
29
+	if err != nil {
30
+		// if we tried to write into a socket and it was broken, it is
31
+		// a time to reconsider the prescence of this socket at all.
32
+		//
33
+		// probably we need to remove it completely because it seems
34
+		// that connection is broken.
35
+		c.pending = 0
36
+	}
37
+	return err
38
+}
39
+
40
+func (c *connection) Read() (conntypes.Packet, error) {
41
+	packet, err := c.conn.Read()
42
+
43
+	c.mutex.Lock()
44
+	if err != nil {
45
+		c.pending--
46
+	} else {
47
+		c.pending = 0
48
+	}
49
+	c.mutex.Unlock()
50
+
51
+	return packet, err
52
+}
53
+
54
+func (c *connection) Stats() (bool, uint) {
55
+	c.mutex.RLock()
56
+	defer c.mutex.RUnlock()
57
+
58
+	return c.closing, c.pending
59
+}
60
+
61
+func (c *connection) Close() error {
62
+	c.mutex.Lock()
63
+	defer c.mutex.Unlock()
64
+
65
+	c.closing = true
66
+	return c.conn.Close()
67
+}
68
+
69
+func (c *connection) run() {
70
+	for {
71
+		packet, err := c.conn.Read()
72
+		if err != nil {
73
+			c.Close()
74
+			c.hub.brokenSocketsChan <- c.id
75
+			c.hub = nil
76
+			return
77
+		}
78
+
79
+		// TODO
80
+		if channel, ok := Registry.getChannel(conntypes.ConnID{}); ok {
81
+			go channel.write(packet) // nolint: errcheck
82
+		}
83
+	}
84
+}
85
+
86
+func newConnection(hub *connectionHub, req *protocol.TelegramRequest) (*connection, error) {
87
+	conn, err := mtproto.TelegramProtocol(req)
88
+	if err != nil {
89
+		return nil, fmt.Errorf("cannot create a new connection: %w", err)
90
+	}
91
+
92
+	rv := &connection{
93
+		conn: conn,
94
+		hub:  hub,
95
+		id:   connectionID(rand.Int()),
96
+	}
97
+	go rv.run()
98
+
99
+	return rv, nil
100
+}

+ 84
- 0
hub/connection_hub.go View File

@@ -0,0 +1,84 @@
1
+package hub
2
+
3
+import "time"
4
+
5
+const hubGCEvery = time.Minute
6
+
7
+type connectionHub struct {
8
+	sockets map[connectionID]*connection
9
+
10
+	brokenSocketsChan      chan connectionID
11
+	connectionRequestsChan chan *connectionHubRequest
12
+	returnConnectionsChan  chan *connection
13
+}
14
+
15
+func (h *connectionHub) run() {
16
+	gcTicker := time.NewTicker(hubGCEvery)
17
+	defer gcTicker.Stop()
18
+
19
+	for {
20
+		select {
21
+		case <-gcTicker.C:
22
+			h.runGC()
23
+		case id := <-h.brokenSocketsChan:
24
+			h.runBrokenConnection(id)
25
+		case request := <-h.connectionRequestsChan:
26
+			h.runConnectionRequest(request)
27
+		case conn := <-h.returnConnectionsChan:
28
+			h.runReturnConnection(conn)
29
+		}
30
+	}
31
+}
32
+
33
+func (h *connectionHub) runBrokenConnection(id connectionID) {
34
+	delete(h.sockets, id)
35
+}
36
+
37
+func (h *connectionHub) runGC() {
38
+	for key, conn := range h.sockets {
39
+		closing, pending := conn.Stats()
40
+		switch {
41
+		case closing:
42
+			delete(h.sockets, key)
43
+		case pending == 0:
44
+			conn.Close()
45
+			delete(h.sockets, key)
46
+			return
47
+		}
48
+
49
+	}
50
+}
51
+
52
+func (h *connectionHub) runConnectionRequest(req *connectionHubRequest) {
53
+	for key, conn := range h.sockets {
54
+		closing, _ := conn.Stats()
55
+		delete(h.sockets, key)
56
+
57
+		if !closing {
58
+			req.responseChan <- conn
59
+			return
60
+		}
61
+	}
62
+
63
+	newConn, err := newConnection(h, req.req)
64
+	if err != nil {
65
+		close(req.responseChan)
66
+		return
67
+	}
68
+
69
+	req.responseChan <- newConn
70
+}
71
+
72
+func (h *connectionHub) runReturnConnection(conn *connection) {
73
+	h.sockets[conn.id] = conn
74
+}
75
+
76
+func newConnectionHub() *connectionHub {
77
+	return &connectionHub{
78
+		sockets: map[connectionID]*connection{},
79
+
80
+		brokenSocketsChan:      make(chan connectionID, 1),
81
+		connectionRequestsChan: make(chan *connectionHubRequest),
82
+		returnConnectionsChan:  make(chan *connection, 1),
83
+	}
84
+}

+ 8
- 0
hub/connection_hub_request.go View File

@@ -0,0 +1,8 @@
1
+package hub
2
+
3
+import "github.com/9seconds/mtg/protocol"
4
+
5
+type connectionHubRequest struct {
6
+	req          *protocol.TelegramRequest
7
+	responseChan chan<- *connection
8
+}

+ 48
- 0
hub/hub.go View File

@@ -0,0 +1,48 @@
1
+package hub
2
+
3
+import (
4
+	"errors"
5
+	"sync"
6
+
7
+	"github.com/9seconds/mtg/conntypes"
8
+	"github.com/9seconds/mtg/protocol"
9
+)
10
+
11
+type Concentrator struct {
12
+	hubs sync.Map
13
+}
14
+
15
+func (c *Concentrator) Write(packet conntypes.Packet, req *protocol.TelegramRequest) error {
16
+	hub := c.getHub(req)
17
+	connectionChan := make(chan *connection)
18
+	hub.connectionRequestsChan <- &connectionHubRequest{
19
+		req:          req,
20
+		responseChan: connectionChan,
21
+	}
22
+
23
+	conn, ok := <-connectionChan
24
+	if !ok {
25
+		return errors.New("cannot establish connection to telegram")
26
+	}
27
+}
28
+
29
+func (c *Concentrator) getHub(req *protocol.TelegramRequest) *connectionHub {
30
+	dcMapRaw, ok := c.hubs.Load(req.ClientProtocol.DC())
31
+	if !ok {
32
+		dcMapRaw, _ = c.hubs.LoadOrStore(req.ClientProtocol.DC(), &sync.Map{})
33
+	}
34
+	dcMap := dcMapRaw.(*sync.Map)
35
+
36
+	loaded := true
37
+	hubRaw, ok := dcMap.Load(req.ClientProtocol.ConnectionProtocol())
38
+	if !ok {
39
+		hubRaw, loaded = dcMap.LoadOrStore(req.ClientProtocol.ConnectionProtocol(),
40
+			newConnectionHub())
41
+	}
42
+	hub := hubRaw.(*connectionHub)
43
+	if !loaded {
44
+		go hub.run()
45
+	}
46
+
47
+	return hub
48
+}

+ 53
- 0
hub/registry.go View File

@@ -0,0 +1,53 @@
1
+package hub
2
+
3
+import (
4
+	"context"
5
+	"sync"
6
+
7
+	"github.com/9seconds/mtg/conntypes"
8
+)
9
+
10
+var Registry *RegistryStruct
11
+
12
+type RegistryStruct struct {
13
+	conns map[string]*closeableChannel
14
+	ctx   context.Context
15
+	mutex sync.RWMutex
16
+}
17
+
18
+func (r *RegistryStruct) Register(id conntypes.ConnID) ChannelReadCloser {
19
+	channel := newCloseableChannel(r.ctx)
20
+
21
+	r.mutex.Lock()
22
+	r.conns[string(id[:])] = channel
23
+	r.mutex.Unlock()
24
+
25
+	return channel
26
+}
27
+
28
+func (r *RegistryStruct) Unregister(id conntypes.ConnID) {
29
+	r.mutex.Lock()
30
+	defer r.mutex.Unlock()
31
+
32
+	if channel, ok := r.conns[string(id[:])]; ok {
33
+		channel.Close()
34
+		delete(r.conns, string(id[:]))
35
+	}
36
+}
37
+
38
+func (r *RegistryStruct) getChannel(id conntypes.ConnID) (*closeableChannel, bool) {
39
+	r.mutex.RLock()
40
+	defer r.mutex.RUnlock()
41
+
42
+	if value, ok := r.conns[string(id[:])]; ok {
43
+		return value, true
44
+	}
45
+	return nil, false
46
+}
47
+
48
+func InitRegistry(ctx context.Context) {
49
+	Registry = &RegistryStruct{
50
+		ctx:   ctx,
51
+		conns: map[string]*closeableChannel{},
52
+	}
53
+}

+ 7
- 8
mtproto/protocol.go View File

@@ -3,16 +3,15 @@ package mtproto
3 3
 import (
4 4
 	"fmt"
5 5
 
6
+	"github.com/9seconds/mtg/conntypes"
6 7
 	"github.com/9seconds/mtg/mtproto/rpc"
7 8
 	"github.com/9seconds/mtg/protocol"
8 9
 	"github.com/9seconds/mtg/telegram"
9 10
 	"github.com/9seconds/mtg/wrappers"
10 11
 )
11 12
 
12
-func TelegramProtocol(req *protocol.TelegramRequest) (wrappers.Wrap, error) {
13
-	conn, err := telegram.Middle.Dial(req.Ctx,
14
-		req.Cancel,
15
-		req.ClientProtocol.DC(),
13
+func TelegramProtocol(req *protocol.TelegramRequest) (conntypes.PacketReadWriteCloser, error) {
14
+	conn, err := telegram.Middle.Dial(req.ClientProtocol.DC(),
16 15
 		req.ClientProtocol.ConnectionProtocol())
17 16
 	if err != nil {
18 17
 		return nil, fmt.Errorf("cannot connect to telegram: %w", err)
@@ -42,7 +41,7 @@ func TelegramProtocol(req *protocol.TelegramRequest) (wrappers.Wrap, error) {
42 41
 	return frameConn, nil
43 42
 }
44 43
 
45
-func doRPCNonceRequest(conn wrappers.PacketWriter) (*rpc.NonceRequest, error) {
44
+func doRPCNonceRequest(conn conntypes.PacketWriter) (*rpc.NonceRequest, error) {
46 45
 	rpcNonceReq, err := rpc.NewNonceRequest(telegram.Middle.Secret())
47 46
 	if err != nil {
48 47
 		panic(err)
@@ -54,7 +53,7 @@ func doRPCNonceRequest(conn wrappers.PacketWriter) (*rpc.NonceRequest, error) {
54 53
 	return rpcNonceReq, nil
55 54
 }
56 55
 
57
-func getRPCNonceResponse(conn wrappers.PacketReader, req *rpc.NonceRequest) (*rpc.NonceResponse, error) {
56
+func getRPCNonceResponse(conn conntypes.PacketReader, req *rpc.NonceRequest) (*rpc.NonceResponse, error) {
58 57
 	packet, err := conn.Read()
59 58
 	if err != nil {
60 59
 		return nil, fmt.Errorf("cannot read from connection: %w", err)
@@ -71,14 +70,14 @@ func getRPCNonceResponse(conn wrappers.PacketReader, req *rpc.NonceRequest) (*rp
71 70
 	return resp, nil
72 71
 }
73 72
 
74
-func doRPCHandshakeRequest(conn wrappers.PacketWriter) error {
73
+func doRPCHandshakeRequest(conn conntypes.PacketWriter) error {
75 74
 	if err := conn.Write(rpc.HandshakeRequest); err != nil {
76 75
 		return fmt.Errorf("cannot make a request: %w", err)
77 76
 	}
78 77
 	return nil
79 78
 }
80 79
 
81
-func getRPCHandshakeResponse(conn wrappers.PacketReader) error {
80
+func getRPCHandshakeResponse(conn conntypes.PacketReader) error {
82 81
 	packet, err := conn.Read()
83 82
 	if err != nil {
84 83
 		return fmt.Errorf("cannot read a response: %w", err)

+ 3
- 3
obfuscated2/client_protocol.go View File

@@ -37,7 +37,7 @@ func (c *ClientProtocol) DC() conntypes.DC {
37 37
 	return c.dc
38 38
 }
39 39
 
40
-func (c *ClientProtocol) Handshake(socket wrappers.StreamReadWriteCloser) (wrappers.StreamReadWriteCloser, error) {
40
+func (c *ClientProtocol) Handshake(socket conntypes.StreamReadWriteCloser) (conntypes.StreamReadWriteCloser, error) {
41 41
 	fm, err := c.ReadFrame(socket)
42 42
 	if err != nil {
43 43
 		return nil, fmt.Errorf("cannot make a client handshake: %w", err)
@@ -88,7 +88,7 @@ func (c *ClientProtocol) Handshake(socket wrappers.StreamReadWriteCloser) (wrapp
88 88
 	return wrappers.NewObfuscated2(socket, encryptor, decryptor), nil
89 89
 }
90 90
 
91
-func (c *ClientProtocol) ReadFrame(socket wrappers.StreamReader) (fm Frame, err error) {
91
+func (c *ClientProtocol) ReadFrame(socket conntypes.StreamReader) (fm Frame, err error) {
92 92
 	if _, err = io.ReadFull(handshakeReader{socket}, fm.Bytes()); err != nil {
93 93
 		err = fmt.Errorf("cannot extract obfuscated2 frame: %w", err)
94 94
 	}
@@ -96,7 +96,7 @@ func (c *ClientProtocol) ReadFrame(socket wrappers.StreamReader) (fm Frame, err
96 96
 }
97 97
 
98 98
 type handshakeReader struct {
99
-	parent wrappers.StreamReader
99
+	parent conntypes.StreamReader
100 100
 }
101 101
 
102 102
 func (h handshakeReader) Read(p []byte) (int, error) {

+ 7
- 6
obfuscated2/telegram_protocol.go View File

@@ -4,20 +4,21 @@ import (
4 4
 	"crypto/rand"
5 5
 	"fmt"
6 6
 
7
+	"github.com/9seconds/mtg/conntypes"
7 8
 	"github.com/9seconds/mtg/protocol"
8 9
 	"github.com/9seconds/mtg/telegram"
9 10
 	"github.com/9seconds/mtg/utils"
10 11
 	"github.com/9seconds/mtg/wrappers"
11 12
 )
12 13
 
13
-func TelegramProtocol(req *protocol.TelegramRequest) (wrappers.Wrap, error) {
14
-	socket, err := telegram.Direct.Dial(req.Ctx,
15
-		req.Cancel,
16
-		req.ClientProtocol.DC(),
14
+func TelegramProtocol(req *protocol.TelegramRequest) (conntypes.StreamReadWriteCloser, error) {
15
+	conn, err := telegram.Direct.Dial(req.ClientProtocol.DC(),
17 16
 		req.ClientProtocol.ConnectionProtocol())
18 17
 	if err != nil {
19 18
 		return nil, fmt.Errorf("cannot dial to telegram: %w", err)
20 19
 	}
20
+	conn = wrappers.NewTimeout(conn)
21
+	conn = wrappers.NewCtx(req.Ctx, req.Cancel, conn)
21 22
 	fm := generateFrame(req.ClientProtocol)
22 23
 	data := fm.Bytes()
23 24
 
@@ -30,11 +31,11 @@ func TelegramProtocol(req *protocol.TelegramRequest) (wrappers.Wrap, error) {
30 31
 	encryptor.XORKeyStream(data, data)
31 32
 	copy(data[:frameOffsetIV], copyFrame[:frameOffsetIV])
32 33
 
33
-	if _, err := socket.Write(data); err != nil {
34
+	if _, err := conn.Write(data); err != nil {
34 35
 		return nil, fmt.Errorf("cannot write handshake frame to telegram: %w", err)
35 36
 	}
36 37
 
37
-	return wrappers.NewObfuscated2(socket, encryptor, decryptor), nil
38
+	return wrappers.NewObfuscated2(conn, encryptor, decryptor), nil
38 39
 }
39 40
 
40 41
 func generateFrame(cp protocol.ClientProtocol) (fm Frame) {

+ 2
- 6
protocol/interfaces.go View File

@@ -1,16 +1,12 @@
1 1
 package protocol
2 2
 
3
-import (
4
-	"github.com/9seconds/mtg/conntypes"
5
-	"github.com/9seconds/mtg/wrappers"
6
-)
3
+import "github.com/9seconds/mtg/conntypes"
7 4
 
8 5
 type ClientProtocol interface {
9
-	Handshake(wrappers.StreamReadWriteCloser) (wrappers.StreamReadWriteCloser, error)
6
+	Handshake(conntypes.StreamReadWriteCloser) (conntypes.StreamReadWriteCloser, error)
10 7
 	ConnectionType() conntypes.ConnectionType
11 8
 	ConnectionProtocol() conntypes.ConnectionProtocol
12 9
 	DC() conntypes.DC
13 10
 }
14 11
 
15
-type TelegramProtocol func(*TelegramRequest) (wrappers.Wrap, error)
16 12
 type ClientProtocolMaker func() ClientProtocol

+ 1
- 2
protocol/request.go View File

@@ -6,12 +6,11 @@ import (
6 6
 	"go.uber.org/zap"
7 7
 
8 8
 	"github.com/9seconds/mtg/conntypes"
9
-	"github.com/9seconds/mtg/wrappers"
10 9
 )
11 10
 
12 11
 type TelegramRequest struct {
13 12
 	Logger         *zap.SugaredLogger
14
-	ClientConn     wrappers.StreamReadWriteCloser
13
+	ClientConn     conntypes.StreamReadWriteCloser
15 14
 	ConnID         conntypes.ConnID
16 15
 	Ctx            context.Context
17 16
 	Cancel         context.CancelFunc

+ 10
- 9
proxy/proxy.go View File

@@ -63,25 +63,26 @@ func (p *Proxy) accept(conn net.Conn) {
63 63
 	ctx, cancel := context.WithCancel(p.Context)
64 64
 	defer cancel()
65 65
 
66
-	wrappedConn := wrappers.NewClientConn(ctx, cancel, conn, connID)
67
-	wrappedConn = wrappers.NewTraffic(wrappedConn)
68
-	defer wrappedConn.Close()
66
+	clientConn := wrappers.NewClientConn(conn, connID)
67
+	clientConn = wrappers.NewCtx(ctx, cancel, clientConn)
68
+	clientConn = wrappers.NewTimeout(clientConn)
69
+	clientConn = wrappers.NewTraffic(clientConn)
70
+	defer clientConn.Close()
69 71
 
70 72
 	clientProtocol := p.ClientProtocolMaker()
71
-	wrappedConn, err := clientProtocol.Handshake(wrappedConn)
73
+	clientConn, err := clientProtocol.Handshake(clientConn)
72 74
 	if err != nil {
73 75
 		logger.Warnw("Cannot perform client handshake", "error", err)
74 76
 		return
75 77
 	}
76
-	defer wrappedConn.Close()
77 78
 
78
-	stats.S.ClientConnected(clientProtocol.ConnectionType(), wrappedConn.RemoteAddr())
79
-	defer stats.S.ClientDisconnected(clientProtocol.ConnectionType(), wrappedConn.RemoteAddr())
79
+	stats.S.ClientConnected(clientProtocol.ConnectionType(), clientConn.RemoteAddr())
80
+	defer stats.S.ClientDisconnected(clientProtocol.ConnectionType(), clientConn.RemoteAddr())
80 81
 	logger.Infow("Client connected", "addr", conn.RemoteAddr())
81 82
 
82 83
 	req := &protocol.TelegramRequest{
83 84
 		Logger:         logger,
84
-		ClientConn:     wrappedConn,
85
+		ClientConn:     clientConn,
85 86
 		ConnID:         connID,
86 87
 		Ctx:            ctx,
87 88
 		Cancel:         cancel,
@@ -102,7 +103,7 @@ func (p *Proxy) acceptDirectConnection(request *protocol.TelegramRequest) error
102 103
 	if err != nil {
103 104
 		return err
104 105
 	}
105
-	telegramConn := telegramConnRaw.(wrappers.StreamReadWriteCloser)
106
+	telegramConn := telegramConnRaw.(conntypes.StreamReadWriteCloser)
106 107
 	defer telegramConn.Close()
107 108
 
108 109
 	wg := &sync.WaitGroup{}

+ 16
- 22
telegram/base.go View File

@@ -1,7 +1,6 @@
1 1
 package telegram
2 2
 
3 3
 import (
4
-	"context"
5 4
 	"fmt"
6 5
 	"math/rand"
7 6
 	"net"
@@ -28,25 +27,8 @@ func (b *baseTelegram) Secret() []byte {
28 27
 	return b.secret
29 28
 }
30 29
 
31
-func (b *baseTelegram) dialToAddress(ctx context.Context,
32
-	cancel context.CancelFunc,
33
-	addr string) (wrappers.StreamReadWriteCloser, error) {
34
-	conn, err := b.dialer.Dial("tcp", addr)
35
-	if err != nil {
36
-		return nil, fmt.Errorf("dial has failed: %w", err)
37
-	}
38
-
39
-	if err := utils.InitTCP(conn); err != nil {
40
-		return nil, fmt.Errorf("cannot initialize tcp socket: %w", err)
41
-	}
42
-
43
-	return wrappers.NewTelegramConn(ctx, cancel, conn), nil
44
-}
45
-
46
-func (b *baseTelegram) dial(ctx context.Context,
47
-	cancel context.CancelFunc,
48
-	dc conntypes.DC,
49
-	protocol conntypes.ConnectionProtocol) (wrappers.StreamReadWriteCloser, error) {
30
+func (b *baseTelegram) dial(dc conntypes.DC,
31
+	protocol conntypes.ConnectionProtocol) (conntypes.StreamReadWriteCloser, error) {
50 32
 	addr := ""
51 33
 
52 34
 	switch protocol {
@@ -56,7 +38,16 @@ func (b *baseTelegram) dial(ctx context.Context,
56 38
 		addr = b.chooseAddress(b.v6Addresses, dc, b.V6DefaultDC)
57 39
 	}
58 40
 
59
-	return b.dialToAddress(ctx, cancel, addr)
41
+	conn, err := b.dialer.Dial("tcp", addr)
42
+	if err != nil {
43
+		return nil, fmt.Errorf("dial has failed: %w", err)
44
+	}
45
+
46
+	if err := utils.InitTCP(conn); err != nil {
47
+		return nil, fmt.Errorf("cannot initialize tcp socket: %w", err)
48
+	}
49
+
50
+	return wrappers.NewTelegramConn(conn), nil
60 51
 }
61 52
 
62 53
 func (b *baseTelegram) chooseAddress(addresses map[conntypes.DC][]string,
@@ -66,7 +57,10 @@ func (b *baseTelegram) chooseAddress(addresses map[conntypes.DC][]string,
66 57
 		addrs, _ = addresses[defaultDC]
67 58
 	}
68 59
 
69
-	if len(addrs) > 0 {
60
+	switch {
61
+	case len(addrs) == 1:
62
+		return addrs[0]
63
+	case len(addrs) > 1:
70 64
 		return addrs[rand.Intn(len(addrs))]
71 65
 	}
72 66
 

+ 11
- 19
telegram/direct.go View File

@@ -1,15 +1,11 @@
1 1
 package telegram
2 2
 
3 3
 import (
4
-	"context"
5 4
 	"net"
6 5
 
7 6
 	"github.com/9seconds/mtg/conntypes"
8
-	"github.com/9seconds/mtg/wrappers"
9 7
 )
10 8
 
11
-var Direct = newDirectTelegram()
12
-
13 9
 const (
14 10
 	directV4DefaultIdx conntypes.DC = 1
15 11
 	directV6DefaultIdx conntypes.DC = 1
@@ -36,10 +32,8 @@ type directTelegram struct {
36 32
 	baseTelegram
37 33
 }
38 34
 
39
-func (d *directTelegram) Dial(ctx context.Context,
40
-	cancel context.CancelFunc,
41
-	dc conntypes.DC,
42
-	protocol conntypes.ConnectionProtocol) (wrappers.StreamReadWriteCloser, error) {
35
+func (d *directTelegram) Dial(dc conntypes.DC,
36
+	protocol conntypes.ConnectionProtocol) (conntypes.StreamReadWriteCloser, error) {
43 37
 	switch {
44 38
 	case dc < 0:
45 39
 		dc = -dc
@@ -47,17 +41,15 @@ func (d *directTelegram) Dial(ctx context.Context,
47 41
 		dc = conntypes.DCDefaultIdx
48 42
 	}
49 43
 
50
-	return d.baseTelegram.dial(ctx, cancel, dc-1, protocol)
44
+	return d.baseTelegram.dial(dc-1, protocol)
51 45
 }
52 46
 
53
-func newDirectTelegram() Telegram {
54
-	return &directTelegram{
55
-		baseTelegram: baseTelegram{
56
-			dialer:      net.Dialer{Timeout: telegramDialTimeout},
57
-			v4DefaultDC: directV4DefaultIdx,
58
-			V6DefaultDC: directV6DefaultIdx,
59
-			v4Addresses: directV4Addresses,
60
-			v6Addresses: directV6Addresses,
61
-		},
62
-	}
47
+var Direct = &directTelegram{
48
+	baseTelegram: baseTelegram{
49
+		dialer:      net.Dialer{Timeout: telegramDialTimeout},
50
+		v4DefaultDC: directV4DefaultIdx,
51
+		V6DefaultDC: directV6DefaultIdx,
52
+		v4Addresses: directV4Addresses,
53
+		v6Addresses: directV6Addresses,
54
+	},
63 55
 }

+ 2
- 10
telegram/interfaces.go View File

@@ -1,16 +1,8 @@
1 1
 package telegram
2 2
 
3
-import (
4
-	"context"
5
-
6
-	"github.com/9seconds/mtg/conntypes"
7
-	"github.com/9seconds/mtg/wrappers"
8
-)
3
+import  "github.com/9seconds/mtg/conntypes"
9 4
 
10 5
 type Telegram interface {
11
-	Dial(context.Context,
12
-		context.CancelFunc,
13
-		conntypes.DC,
14
-		conntypes.ConnectionProtocol) (wrappers.StreamReadWriteCloser, error)
6
+	Dial(conntypes.DC, conntypes.ConnectionProtocol) (conntypes.StreamReadWriteCloser, error)
15 7
 	Secret() []byte
16 8
 }

+ 3
- 7
telegram/middle.go View File

@@ -1,7 +1,6 @@
1 1
 package telegram
2 2
 
3 3
 import (
4
-	"context"
5 4
 	"fmt"
6 5
 	"net"
7 6
 	"sync"
@@ -11,7 +10,6 @@ import (
11 10
 
12 11
 	"github.com/9seconds/mtg/conntypes"
13 12
 	"github.com/9seconds/mtg/telegram/api"
14
-	"github.com/9seconds/mtg/wrappers"
15 13
 )
16 14
 
17 15
 const middleTelegramBackgroundUpdateEvery = time.Hour
@@ -67,10 +65,8 @@ func (m *middleTelegram) backgroundUpdate() {
67 65
 	}
68 66
 }
69 67
 
70
-func (m *middleTelegram) Dial(ctx context.Context,
71
-	cancel context.CancelFunc,
72
-	dc conntypes.DC,
73
-	protocol conntypes.ConnectionProtocol) (wrappers.StreamReadWriteCloser, error) {
68
+func (m *middleTelegram) Dial(dc conntypes.DC,
69
+	protocol conntypes.ConnectionProtocol) (conntypes.StreamReadWriteCloser, error) {
74 70
 	if dc == 0 {
75 71
 		dc = conntypes.DCDefaultIdx
76 72
 	}
@@ -78,7 +74,7 @@ func (m *middleTelegram) Dial(ctx context.Context,
78 74
 	m.mutex.RLock()
79 75
 	defer m.mutex.RUnlock()
80 76
 
81
-	return m.baseTelegram.dial(ctx, cancel, dc, protocol)
77
+	return m.baseTelegram.dial(dc, protocol)
82 78
 }
83 79
 
84 80
 func MiddleInit() {

+ 10
- 6
wrappers/blockcipher.go View File

@@ -10,6 +10,8 @@ import (
10 10
 	"time"
11 11
 
12 12
 	"go.uber.org/zap"
13
+
14
+	"github.com/9seconds/mtg/conntypes"
13 15
 )
14 16
 
15 17
 const blockCipherReadCurrentDataBufferSize = 1024 + 1 // +1 because telegram operates with blocks mod 4
@@ -17,7 +19,7 @@ const blockCipherReadCurrentDataBufferSize = 1024 + 1 // +1 because telegram ope
17 19
 type wrapperBlockCipher struct {
18 20
 	buf bytes.Buffer
19 21
 
20
-	parent    StreamReadWriteCloser
22
+	parent    conntypes.StreamReadWriteCloser
21 23
 	encryptor cipher.BlockMode
22 24
 	decryptor cipher.BlockMode
23 25
 }
@@ -47,7 +49,8 @@ func (w *wrapperBlockCipher) ReadTimeout(p []byte, timeout time.Duration) (int,
47 49
 	return w.read(p, readAllTimeout(timeout))
48 50
 }
49 51
 
50
-func (w *wrapperBlockCipher) read(p []byte, reader func(StreamReadWriteCloser) ([]byte, error)) (int, error) {
52
+func (w *wrapperBlockCipher) read(p []byte,
53
+	reader func(conntypes.StreamReadWriteCloser) ([]byte, error)) (int, error) {
51 54
 	if w.buf.Len() > 0 {
52 55
 		return w.flush(p)
53 56
 	}
@@ -90,7 +93,7 @@ func (w *wrapperBlockCipher) encrypt(p []byte) ([]byte, error) {
90 93
 	return encrypted, nil
91 94
 }
92 95
 
93
-func readAll(src StreamReadWriteCloser) (rv []byte, err error) {
96
+func readAll(src conntypes.StreamReadWriteCloser) (rv []byte, err error) {
94 97
 	buf := make([]byte, blockCipherReadCurrentDataBufferSize)
95 98
 	n := blockCipherReadCurrentDataBufferSize
96 99
 
@@ -105,8 +108,8 @@ func readAll(src StreamReadWriteCloser) (rv []byte, err error) {
105 108
 	return rv, nil
106 109
 }
107 110
 
108
-func readAllTimeout(timeout time.Duration) func(StreamReadWriteCloser) ([]byte, error) {
109
-	return func(src StreamReadWriteCloser) (rv []byte, err error) {
111
+func readAllTimeout(timeout time.Duration) func(conntypes.StreamReadWriteCloser) ([]byte, error) {
112
+	return func(src conntypes.StreamReadWriteCloser) (rv []byte, err error) {
110 113
 		tmo := timeout
111 114
 		buf := make([]byte, blockCipherReadCurrentDataBufferSize)
112 115
 		n := blockCipherReadCurrentDataBufferSize
@@ -148,7 +151,8 @@ func (w *wrapperBlockCipher) RemoteAddr() *net.TCPAddr {
148 151
 	return w.parent.RemoteAddr()
149 152
 }
150 153
 
151
-func newBlockCipher(parent StreamReadWriteCloser, encryptor, decryptor cipher.BlockMode) StreamReadWriteCloser {
154
+func newBlockCipher(parent conntypes.StreamReadWriteCloser,
155
+	encryptor, decryptor cipher.BlockMode) conntypes.StreamReadWriteCloser {
152 156
 	return &wrapperBlockCipher{
153 157
 		parent:    parent,
154 158
 		encryptor: encryptor,

+ 29
- 61
wrappers/conn.go View File

@@ -1,7 +1,6 @@
1 1
 package wrappers
2 2
 
3 3
 import (
4
-	"context"
5 4
 	"fmt"
6 5
 	"net"
7 6
 	"time"
@@ -19,15 +18,8 @@ const (
19 18
 	connPurposeTelegram
20 19
 )
21 20
 
22
-const (
23
-	connTimeoutRead  = 2 * time.Minute
24
-	connTimeoutWrite = 2 * time.Minute
25
-)
26
-
27 21
 type wrapperConn struct {
28 22
 	parent     net.Conn
29
-	ctx        context.Context
30
-	cancel     context.CancelFunc
31 23
 	connID     conntypes.ConnID
32 24
 	logger     *zap.SugaredLogger
33 25
 	localAddr  *net.TCPAddr
@@ -35,61 +27,45 @@ type wrapperConn struct {
35 27
 }
36 28
 
37 29
 func (w *wrapperConn) WriteTimeout(p []byte, timeout time.Duration) (int, error) {
38
-	select {
39
-	case <-w.ctx.Done():
30
+	if err := w.parent.SetWriteDeadline(time.Now().Add(timeout)); err != nil {
40 31
 		w.Close()
41
-		return 0, fmt.Errorf("cannot write because context was closed: %w", w.ctx.Err())
42
-
43
-	default:
44
-		if err := w.parent.SetWriteDeadline(time.Now().Add(timeout)); err != nil {
45
-			w.Close() // nolint: gosec
46
-			return 0, fmt.Errorf("cannot set write deadline to the socket: %w", err)
47
-		}
48
-
49
-		n, err := w.parent.Write(p)
50
-		w.logger.Debugw("Write to stream", "bytes", n, "error", err)
51
-		if err != nil {
52
-			w.Close() // nolint: gosec
53
-		}
54
-
55
-		return n, err
32
+		return 0, fmt.Errorf("cannot set write deadline to the socket: %w", err)
56 33
 	}
34
+
35
+	return w.Write(p)
57 36
 }
58 37
 
59 38
 func (w *wrapperConn) Write(p []byte) (int, error) {
60
-	return w.WriteTimeout(p, connTimeoutWrite)
39
+	n, err := w.parent.Write(p)
40
+	w.logger.Debugw("write to stream", "bytes", n, "error", err)
41
+	if err != nil {
42
+		w.Close() // nolint: gosec
43
+	}
44
+
45
+	return n, err
61 46
 }
62 47
 
63 48
 func (w *wrapperConn) ReadTimeout(p []byte, timeout time.Duration) (int, error) {
64
-	select {
65
-	case <-w.ctx.Done():
49
+	if err := w.parent.SetReadDeadline(time.Now().Add(timeout)); err != nil {
66 50
 		w.Close()
67
-		return 0, fmt.Errorf("cannot read because context was closed: %w", w.ctx.Err())
68
-
69
-	default:
70
-		if err := w.parent.SetReadDeadline(time.Now().Add(timeout)); err != nil {
71
-			w.Close()
72
-			return 0, fmt.Errorf("cannot set read deadline to the socket: %w", err)
73
-		}
74
-
75
-		n, err := w.parent.Read(p)
76
-		w.logger.Debugw("Read from stream", "bytes", n, "error", err)
77
-		if err != nil {
78
-			w.Close()
79
-		}
80
-
81
-		return n, err
51
+		return 0, fmt.Errorf("cannot set read deadline to the socket: %w", err)
82 52
 	}
53
+
54
+	return w.Read(p)
83 55
 }
84 56
 
85 57
 func (w *wrapperConn) Read(p []byte) (int, error) {
86
-	return w.ReadTimeout(p, connTimeoutRead)
58
+	n, err := w.parent.Read(p)
59
+	w.logger.Debugw("Read from stream", "bytes", n, "error", err)
60
+	if err != nil {
61
+		w.Close()
62
+	}
63
+
64
+	return n, err
87 65
 }
88 66
 
89 67
 func (w *wrapperConn) Close() error {
90 68
 	w.logger.Debugw("Close connection")
91
-	w.cancel()
92
-
93 69
 	return w.parent.Close()
94 70
 }
95 71
 
@@ -109,11 +85,9 @@ func (w *wrapperConn) RemoteAddr() *net.TCPAddr {
109 85
 	return w.remoteAddr
110 86
 }
111 87
 
112
-func newConn(ctx context.Context,
113
-	cancel context.CancelFunc,
114
-	parent net.Conn,
88
+func newConn(parent net.Conn,
115 89
 	connID conntypes.ConnID,
116
-	purpose connPurpose) StreamReadWriteCloser {
90
+	purpose connPurpose) conntypes.StreamReadWriteCloser {
117 91
 	localAddr := *parent.LocalAddr().(*net.TCPAddr)
118 92
 
119 93
 	if parent.RemoteAddr().(*net.TCPAddr).IP.To4() != nil {
@@ -135,8 +109,6 @@ func newConn(ctx context.Context,
135 109
 
136 110
 	return &wrapperConn{
137 111
 		parent:     parent,
138
-		ctx:        ctx,
139
-		cancel:     cancel,
140 112
 		connID:     connID,
141 113
 		logger:     logger,
142 114
 		remoteAddr: parent.RemoteAddr().(*net.TCPAddr),
@@ -144,15 +116,11 @@ func newConn(ctx context.Context,
144 116
 	}
145 117
 }
146 118
 
147
-func NewClientConn(ctx context.Context,
148
-	cancel context.CancelFunc,
149
-	parent net.Conn,
150
-	connID conntypes.ConnID) StreamReadWriteCloser {
151
-	return newConn(ctx, cancel, parent, connID, connPurposeClient)
119
+func NewClientConn(parent net.Conn,
120
+	connID conntypes.ConnID) conntypes.StreamReadWriteCloser {
121
+	return newConn(parent, connID, connPurposeClient)
152 122
 }
153 123
 
154
-func NewTelegramConn(ctx context.Context,
155
-	cancel context.CancelFunc,
156
-	parent net.Conn) StreamReadWriteCloser {
157
-	return newConn(ctx, cancel, parent, conntypes.ConnID{}, connPurposeTelegram)
124
+func NewTelegramConn(parent net.Conn) conntypes.StreamReadWriteCloser {
125
+	return newConn(parent, conntypes.ConnID{}, connPurposeTelegram)
158 126
 }

+ 89
- 0
wrappers/ctx.go View File

@@ -0,0 +1,89 @@
1
+package wrappers
2
+
3
+import (
4
+	"context"
5
+	"fmt"
6
+	"net"
7
+	"time"
8
+
9
+	"go.uber.org/zap"
10
+
11
+	"github.com/9seconds/mtg/conntypes"
12
+)
13
+
14
+type wrapperCtx struct {
15
+	parent conntypes.StreamReadWriteCloser
16
+	ctx    context.Context
17
+	cancel context.CancelFunc
18
+}
19
+
20
+func (w *wrapperCtx) WriteTimeout(p []byte, timeout time.Duration) (int, error) {
21
+	select {
22
+	case <-w.ctx.Done():
23
+		w.Close()
24
+		return 0, fmt.Errorf("cannot write because context was closed: %w", w.ctx.Err())
25
+	default:
26
+		return w.parent.WriteTimeout(p, timeout)
27
+	}
28
+}
29
+
30
+func (w *wrapperCtx) Write(p []byte) (int, error) {
31
+	select {
32
+	case <-w.ctx.Done():
33
+		w.Close()
34
+		return 0, fmt.Errorf("cannot write because context was closed: %w", w.ctx.Err())
35
+	default:
36
+		return w.parent.Write(p)
37
+	}
38
+}
39
+
40
+func (w *wrapperCtx) ReadTimeout(p []byte, timeout time.Duration) (int, error) {
41
+	select {
42
+	case <-w.ctx.Done():
43
+		w.Close()
44
+		return 0, fmt.Errorf("cannot write because context was closed: %w", w.ctx.Err())
45
+	default:
46
+		return w.parent.ReadTimeout(p, timeout)
47
+	}
48
+}
49
+
50
+func (w *wrapperCtx) Read(p []byte) (int, error) {
51
+	select {
52
+	case <-w.ctx.Done():
53
+		w.Close()
54
+		return 0, fmt.Errorf("cannot write because context was closed: %w", w.ctx.Err())
55
+	default:
56
+		return w.parent.Read(p)
57
+	}
58
+}
59
+
60
+func (w *wrapperCtx) Close() error {
61
+	w.cancel()
62
+	return w.parent.Close()
63
+}
64
+
65
+func (w *wrapperCtx) Conn() net.Conn {
66
+	return w.parent.Conn()
67
+}
68
+
69
+func (w *wrapperCtx) Logger() *zap.SugaredLogger {
70
+	return w.parent.Logger().Named("ctx")
71
+}
72
+
73
+func (w *wrapperCtx) LocalAddr() *net.TCPAddr {
74
+	return w.parent.LocalAddr()
75
+}
76
+
77
+func (w *wrapperCtx) RemoteAddr() *net.TCPAddr {
78
+	return w.parent.RemoteAddr()
79
+}
80
+
81
+func NewCtx(ctx context.Context,
82
+	cancel context.CancelFunc,
83
+	parent conntypes.StreamReadWriteCloser) conntypes.StreamReadWriteCloser {
84
+	return &wrapperCtx{
85
+		parent: parent,
86
+		ctx:    ctx,
87
+		cancel: cancel,
88
+	}
89
+}

+ 3
- 2
wrappers/mtproto_cipher.go View File

@@ -9,6 +9,7 @@ import (
9 9
 	"encoding/binary"
10 10
 	"net"
11 11
 
12
+	"github.com/9seconds/mtg/conntypes"
12 13
 	"github.com/9seconds/mtg/mtproto/rpc"
13 14
 	"github.com/9seconds/mtg/utils"
14 15
 )
@@ -22,10 +23,10 @@ const (
22 23
 
23 24
 var mtprotoEmptyIP = [4]byte{0x00, 0x00, 0x00, 0x00}
24 25
 
25
-func NewMiddleProxyCipher(parent StreamReadWriteCloser,
26
+func NewMiddleProxyCipher(parent conntypes.StreamReadWriteCloser,
26 27
 	req *rpc.NonceRequest,
27 28
 	resp *rpc.NonceResponse,
28
-	secret []byte) StreamReadWriteCloser {
29
+	secret []byte) conntypes.StreamReadWriteCloser {
29 30
 	localAddr := parent.LocalAddr()
30 31
 	remoteAddr := parent.RemoteAddr()
31 32
 

+ 6
- 4
wrappers/mtproto_frame.go View File

@@ -11,6 +11,8 @@ import (
11 11
 	"net"
12 12
 
13 13
 	"go.uber.org/zap"
14
+
15
+	"github.com/9seconds/mtg/conntypes"
14 16
 )
15 17
 
16 18
 const (
@@ -33,13 +35,13 @@ var mtprotoFramePadding = []byte{0x04, 0x00, 0x00, 0x00}
33 35
 // PADDING is custom padding schema to complete frame length to such that
34 36
 //    len(frame) % 16 == 0
35 37
 type wrapperMtprotoFrame struct {
36
-	parent     StreamReadWriteCloser
38
+	parent     conntypes.StreamReadWriteCloser
37 39
 	logger     *zap.SugaredLogger
38 40
 	readSeqNo  int32
39 41
 	writeSeqNo int32
40 42
 }
41 43
 
42
-func (w *wrapperMtprotoFrame) Read() (Packet, error) {
44
+func (w *wrapperMtprotoFrame) Read() (conntypes.Packet, error) {
43 45
 	buf := &bytes.Buffer{}
44 46
 	sum := crc32.NewIEEE()
45 47
 	writer := io.MultiWriter(buf, sum)
@@ -101,7 +103,7 @@ func (w *wrapperMtprotoFrame) Read() (Packet, error) {
101 103
 	return data, nil
102 104
 }
103 105
 
104
-func (w *wrapperMtprotoFrame) Write(p Packet) error {
106
+func (w *wrapperMtprotoFrame) Write(p conntypes.Packet) error {
105 107
 	messageLength := 4 + 4 + len(p) + 4
106 108
 	paddingLength := (aes.BlockSize - messageLength%aes.BlockSize) % aes.BlockSize
107 109
 
@@ -149,7 +151,7 @@ func (w *wrapperMtprotoFrame) RemoteAddr() *net.TCPAddr {
149 151
 	return w.parent.RemoteAddr()
150 152
 }
151 153
 
152
-func NewMtprotoFrame(parent StreamReadWriteCloser, seqNo int32) PacketReadWriteCloser {
154
+func NewMtprotoFrame(parent conntypes.StreamReadWriteCloser, seqNo int32) conntypes.PacketReadWriteCloser {
153 155
 	return &wrapperMtprotoFrame{
154 156
 		parent:     parent,
155 157
 		logger:     parent.Logger().Named("mtproto-frame"),

+ 5
- 2
wrappers/obfuscated2.go View File

@@ -7,12 +7,14 @@ import (
7 7
 	"time"
8 8
 
9 9
 	"go.uber.org/zap"
10
+
11
+	"github.com/9seconds/mtg/conntypes"
10 12
 )
11 13
 
12 14
 type wrapperObfuscated2 struct {
13 15
 	encryptor cipher.Stream
14 16
 	decryptor cipher.Stream
15
-	parent    StreamReadWriteCloser
17
+	parent    conntypes.StreamReadWriteCloser
16 18
 }
17 19
 
18 20
 func (w *wrapperObfuscated2) ReadTimeout(p []byte, timeout time.Duration) (int, error) {
@@ -71,7 +73,8 @@ func (w *wrapperObfuscated2) Close() error {
71 73
 	return w.parent.Close()
72 74
 }
73 75
 
74
-func NewObfuscated2(socket StreamReadWriteCloser, encryptor, decryptor cipher.Stream) StreamReadWriteCloser {
76
+func NewObfuscated2(socket conntypes.StreamReadWriteCloser,
77
+	encryptor, decryptor cipher.Stream) conntypes.StreamReadWriteCloser {
75 78
 	return &wrapperObfuscated2{
76 79
 		parent:    socket,
77 80
 		encryptor: encryptor,

+ 3
- 2
wrappers/stats.go View File

@@ -6,11 +6,12 @@ import (
6 6
 
7 7
 	"go.uber.org/zap"
8 8
 
9
+	"github.com/9seconds/mtg/conntypes"
9 10
 	"github.com/9seconds/mtg/stats"
10 11
 )
11 12
 
12 13
 type wrapperStats struct {
13
-	parent StreamReadWriteCloser
14
+	parent conntypes.StreamReadWriteCloser
14 15
 }
15 16
 
16 17
 func (w *wrapperStats) Write(p []byte) (int, error) {
@@ -61,6 +62,6 @@ func (w *wrapperStats) Close() error {
61 62
 	return w.parent.Close()
62 63
 }
63 64
 
64
-func NewTraffic(parent StreamReadWriteCloser) StreamReadWriteCloser {
65
+func NewTraffic(parent conntypes.StreamReadWriteCloser) conntypes.StreamReadWriteCloser {
65 66
 	return &wrapperStats{parent}
66 67
 }

+ 61
- 0
wrappers/timeout.go View File

@@ -0,0 +1,61 @@
1
+package wrappers
2
+
3
+import (
4
+	"net"
5
+	"time"
6
+
7
+	"go.uber.org/zap"
8
+
9
+	"github.com/9seconds/mtg/conntypes"
10
+)
11
+
12
+const (
13
+	timeoutRead  = 2 * time.Minute
14
+	timeoutWrite = 2 * time.Minute
15
+)
16
+
17
+type wrapperTimeout struct {
18
+	parent conntypes.StreamReadWriteCloser
19
+}
20
+
21
+func (w *wrapperTimeout) WriteTimeout(p []byte, timeout time.Duration) (int, error) {
22
+	return w.parent.WriteTimeout(p, timeout)
23
+}
24
+
25
+func (w *wrapperTimeout) Write(p []byte) (int, error) {
26
+	return w.parent.WriteTimeout(p, timeoutWrite)
27
+}
28
+
29
+func (w *wrapperTimeout) ReadTimeout(p []byte, timeout time.Duration) (int, error) {
30
+	return w.parent.ReadTimeout(p, timeout)
31
+}
32
+
33
+func (w *wrapperTimeout) Read(p []byte) (int, error) {
34
+	return w.parent.ReadTimeout(p, timeoutRead)
35
+}
36
+
37
+func (w *wrapperTimeout) Close() error {
38
+	return w.parent.Close()
39
+}
40
+
41
+func (w *wrapperTimeout) Conn() net.Conn {
42
+	return w.parent.Conn()
43
+}
44
+
45
+func (w *wrapperTimeout) Logger() *zap.SugaredLogger {
46
+	return w.parent.Logger().Named("timeout")
47
+}
48
+
49
+func (w *wrapperTimeout) LocalAddr() *net.TCPAddr {
50
+	return w.parent.LocalAddr()
51
+}
52
+
53
+func (w *wrapperTimeout) RemoteAddr() *net.TCPAddr {
54
+	return w.parent.RemoteAddr()
55
+}
56
+
57
+func NewTimeout(parent conntypes.StreamReadWriteCloser) conntypes.StreamReadWriteCloser {
58
+	return &wrapperTimeout{
59
+		parent: parent,
60
+	}
61
+}

Loading…
Cancel
Save