ソースを参照

wip

tags/1.0^2
9seconds 6年前
コミット
c9743b5675

+ 3
- 0
conntypes/packet.go ファイルの表示

1
+package conntypes
2
+
3
+type Packet []byte

wrappers/interfaces.go → conntypes/wrappers.go ファイルの表示

1
-package wrappers
1
+package conntypes
2
 
2
 
3
 import (
3
 import (
4
 	"io"
4
 	"io"
8
 	"go.uber.org/zap"
8
 	"go.uber.org/zap"
9
 )
9
 )
10
 
10
 
11
-type Packet []byte
12
-
13
 // Wrap is a base interface for all wrappers in this package.
11
 // Wrap is a base interface for all wrappers in this package.
14
 type Wrap interface {
12
 type Wrap interface {
15
 	Conn() net.Conn
13
 	Conn() net.Conn

+ 2
- 0
go.mod ファイルの表示

9
 	github.com/allegro/bigcache v1.2.1
9
 	github.com/allegro/bigcache v1.2.1
10
 	github.com/beevik/ntp v0.2.0
10
 	github.com/beevik/ntp v0.2.0
11
 	github.com/cespare/xxhash v1.1.0
11
 	github.com/cespare/xxhash v1.1.0
12
+	github.com/dustin/go-humanize v1.0.0
13
+	github.com/gammazero/deque v0.0.0-20190521012701-46e4ffb7a622
12
 	github.com/juju/errors v0.0.0-20190806202954-0232dcc7464d
14
 	github.com/juju/errors v0.0.0-20190806202954-0232dcc7464d
13
 	github.com/kr/pretty v0.1.0 // indirect
15
 	github.com/kr/pretty v0.1.0 // indirect
14
 	github.com/pkg/errors v0.8.1
16
 	github.com/pkg/errors v0.8.1

+ 4
- 0
go.sum ファイルの表示

24
 github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
24
 github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
25
 github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
25
 github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
26
 github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
26
 github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
27
+github.com/dustin/go-humanize v1.0.0 h1:VSnTsYCnlFHaM2/igO1h6X3HA71jcobQuxemgkq4zYo=
28
+github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk=
29
+github.com/gammazero/deque v0.0.0-20190521012701-46e4ffb7a622 h1:lxbhOGZ9pU3Kf8P6lFluUcE82yVZn2EqEf4+mWRNPV0=
30
+github.com/gammazero/deque v0.0.0-20190521012701-46e4ffb7a622/go.mod h1:D90+MBHVc9Sk1lJAbEVgws0eYEurY4mv2TDso3Nxh3w=
27
 github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
31
 github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
28
 github.com/go-logfmt/logfmt v0.3.0/go.mod h1:Qt1PoO58o5twSAckw1HlFXLmHsOX5/0LbT9GBnD5lWE=
32
 github.com/go-logfmt/logfmt v0.3.0/go.mod h1:Qt1PoO58o5twSAckw1HlFXLmHsOX5/0LbT9GBnD5lWE=
29
 github.com/go-logfmt/logfmt v0.4.0/go.mod h1:3RMwSq7FuexP4Kalkev3ejPJsZTpXXBr9+V4qmtdjCk=
33
 github.com/go-logfmt/logfmt v0.4.0/go.mod h1:3RMwSq7FuexP4Kalkev3ejPJsZTpXXBr9+V4qmtdjCk=

+ 59
- 0
hub/closeable_channel.go ファイルの表示

1
+package hub
2
+
3
+import (
4
+	"context"
5
+	"errors"
6
+	"time"
7
+
8
+	"github.com/9seconds/mtg/conntypes"
9
+)
10
+
11
+const closeableChannelReadTimeout = 2 * time.Minute
12
+
13
+type ChannelReadCloser interface {
14
+	Read() (conntypes.Packet, error)
15
+	Close()
16
+}
17
+
18
+type closeableChannel struct {
19
+	channel chan conntypes.Packet
20
+	ctx     context.Context
21
+	cancel  context.CancelFunc
22
+}
23
+
24
+func (c *closeableChannel) Read() (conntypes.Packet, error) {
25
+	timer := time.NewTimer(closeableChannelReadTimeout)
26
+	defer timer.Stop()
27
+
28
+	select {
29
+	case <-timer.C:
30
+		return nil, errors.New("timeout")
31
+	case <-c.ctx.Done():
32
+		return nil, errors.New("channel was closed")
33
+	case packet := <-c.channel:
34
+		return packet, nil
35
+	}
36
+}
37
+
38
+func (c *closeableChannel) write(packet conntypes.Packet) error {
39
+	select {
40
+	case <-c.ctx.Done():
41
+		return errors.New("channel was closed")
42
+	case c.channel <- packet:
43
+		return nil
44
+	}
45
+}
46
+
47
+func (c *closeableChannel) Close() {
48
+	c.cancel()
49
+	c.channel = nil
50
+}
51
+
52
+func newCloseableChannel(ctx context.Context) *closeableChannel {
53
+	ctx, cancel := context.WithCancel(ctx)
54
+	return &closeableChannel{
55
+		channel: make(chan conntypes.Packet),
56
+		ctx:     ctx,
57
+		cancel:  cancel,
58
+	}
59
+}

+ 100
- 0
hub/connection.go ファイルの表示

1
+package hub
2
+
3
+import (
4
+	"fmt"
5
+	"math/rand"
6
+	"sync"
7
+
8
+	"github.com/9seconds/mtg/conntypes"
9
+	"github.com/9seconds/mtg/mtproto"
10
+	"github.com/9seconds/mtg/protocol"
11
+)
12
+
13
+type connectionID int
14
+
15
+type connection struct {
16
+	conn    conntypes.PacketReadWriteCloser
17
+	mutex   sync.RWMutex
18
+	id      connectionID
19
+	hub     *connectionHub
20
+	pending uint
21
+	closing bool
22
+}
23
+
24
+func (c *connection) Write(packet conntypes.Packet) error {
25
+	c.mutex.Lock()
26
+	defer c.mutex.Unlock()
27
+
28
+	err := c.conn.Write(packet)
29
+	if err != nil {
30
+		// if we tried to write into a socket and it was broken, it is
31
+		// a time to reconsider the prescence of this socket at all.
32
+		//
33
+		// probably we need to remove it completely because it seems
34
+		// that connection is broken.
35
+		c.pending = 0
36
+	}
37
+	return err
38
+}
39
+
40
+func (c *connection) Read() (conntypes.Packet, error) {
41
+	packet, err := c.conn.Read()
42
+
43
+	c.mutex.Lock()
44
+	if err != nil {
45
+		c.pending--
46
+	} else {
47
+		c.pending = 0
48
+	}
49
+	c.mutex.Unlock()
50
+
51
+	return packet, err
52
+}
53
+
54
+func (c *connection) Stats() (bool, uint) {
55
+	c.mutex.RLock()
56
+	defer c.mutex.RUnlock()
57
+
58
+	return c.closing, c.pending
59
+}
60
+
61
+func (c *connection) Close() error {
62
+	c.mutex.Lock()
63
+	defer c.mutex.Unlock()
64
+
65
+	c.closing = true
66
+	return c.conn.Close()
67
+}
68
+
69
+func (c *connection) run() {
70
+	for {
71
+		packet, err := c.conn.Read()
72
+		if err != nil {
73
+			c.Close()
74
+			c.hub.brokenSocketsChan <- c.id
75
+			c.hub = nil
76
+			return
77
+		}
78
+
79
+		// TODO
80
+		if channel, ok := Registry.getChannel(conntypes.ConnID{}); ok {
81
+			go channel.write(packet) // nolint: errcheck
82
+		}
83
+	}
84
+}
85
+
86
+func newConnection(hub *connectionHub, req *protocol.TelegramRequest) (*connection, error) {
87
+	conn, err := mtproto.TelegramProtocol(req)
88
+	if err != nil {
89
+		return nil, fmt.Errorf("cannot create a new connection: %w", err)
90
+	}
91
+
92
+	rv := &connection{
93
+		conn: conn,
94
+		hub:  hub,
95
+		id:   connectionID(rand.Int()),
96
+	}
97
+	go rv.run()
98
+
99
+	return rv, nil
100
+}

+ 84
- 0
hub/connection_hub.go ファイルの表示

1
+package hub
2
+
3
+import "time"
4
+
5
+const hubGCEvery = time.Minute
6
+
7
+type connectionHub struct {
8
+	sockets map[connectionID]*connection
9
+
10
+	brokenSocketsChan      chan connectionID
11
+	connectionRequestsChan chan *connectionHubRequest
12
+	returnConnectionsChan  chan *connection
13
+}
14
+
15
+func (h *connectionHub) run() {
16
+	gcTicker := time.NewTicker(hubGCEvery)
17
+	defer gcTicker.Stop()
18
+
19
+	for {
20
+		select {
21
+		case <-gcTicker.C:
22
+			h.runGC()
23
+		case id := <-h.brokenSocketsChan:
24
+			h.runBrokenConnection(id)
25
+		case request := <-h.connectionRequestsChan:
26
+			h.runConnectionRequest(request)
27
+		case conn := <-h.returnConnectionsChan:
28
+			h.runReturnConnection(conn)
29
+		}
30
+	}
31
+}
32
+
33
+func (h *connectionHub) runBrokenConnection(id connectionID) {
34
+	delete(h.sockets, id)
35
+}
36
+
37
+func (h *connectionHub) runGC() {
38
+	for key, conn := range h.sockets {
39
+		closing, pending := conn.Stats()
40
+		switch {
41
+		case closing:
42
+			delete(h.sockets, key)
43
+		case pending == 0:
44
+			conn.Close()
45
+			delete(h.sockets, key)
46
+			return
47
+		}
48
+
49
+	}
50
+}
51
+
52
+func (h *connectionHub) runConnectionRequest(req *connectionHubRequest) {
53
+	for key, conn := range h.sockets {
54
+		closing, _ := conn.Stats()
55
+		delete(h.sockets, key)
56
+
57
+		if !closing {
58
+			req.responseChan <- conn
59
+			return
60
+		}
61
+	}
62
+
63
+	newConn, err := newConnection(h, req.req)
64
+	if err != nil {
65
+		close(req.responseChan)
66
+		return
67
+	}
68
+
69
+	req.responseChan <- newConn
70
+}
71
+
72
+func (h *connectionHub) runReturnConnection(conn *connection) {
73
+	h.sockets[conn.id] = conn
74
+}
75
+
76
+func newConnectionHub() *connectionHub {
77
+	return &connectionHub{
78
+		sockets: map[connectionID]*connection{},
79
+
80
+		brokenSocketsChan:      make(chan connectionID, 1),
81
+		connectionRequestsChan: make(chan *connectionHubRequest),
82
+		returnConnectionsChan:  make(chan *connection, 1),
83
+	}
84
+}

+ 8
- 0
hub/connection_hub_request.go ファイルの表示

1
+package hub
2
+
3
+import "github.com/9seconds/mtg/protocol"
4
+
5
+type connectionHubRequest struct {
6
+	req          *protocol.TelegramRequest
7
+	responseChan chan<- *connection
8
+}

+ 48
- 0
hub/hub.go ファイルの表示

1
+package hub
2
+
3
+import (
4
+	"errors"
5
+	"sync"
6
+
7
+	"github.com/9seconds/mtg/conntypes"
8
+	"github.com/9seconds/mtg/protocol"
9
+)
10
+
11
+type Concentrator struct {
12
+	hubs sync.Map
13
+}
14
+
15
+func (c *Concentrator) Write(packet conntypes.Packet, req *protocol.TelegramRequest) error {
16
+	hub := c.getHub(req)
17
+	connectionChan := make(chan *connection)
18
+	hub.connectionRequestsChan <- &connectionHubRequest{
19
+		req:          req,
20
+		responseChan: connectionChan,
21
+	}
22
+
23
+	conn, ok := <-connectionChan
24
+	if !ok {
25
+		return errors.New("cannot establish connection to telegram")
26
+	}
27
+}
28
+
29
+func (c *Concentrator) getHub(req *protocol.TelegramRequest) *connectionHub {
30
+	dcMapRaw, ok := c.hubs.Load(req.ClientProtocol.DC())
31
+	if !ok {
32
+		dcMapRaw, _ = c.hubs.LoadOrStore(req.ClientProtocol.DC(), &sync.Map{})
33
+	}
34
+	dcMap := dcMapRaw.(*sync.Map)
35
+
36
+	loaded := true
37
+	hubRaw, ok := dcMap.Load(req.ClientProtocol.ConnectionProtocol())
38
+	if !ok {
39
+		hubRaw, loaded = dcMap.LoadOrStore(req.ClientProtocol.ConnectionProtocol(),
40
+			newConnectionHub())
41
+	}
42
+	hub := hubRaw.(*connectionHub)
43
+	if !loaded {
44
+		go hub.run()
45
+	}
46
+
47
+	return hub
48
+}

+ 53
- 0
hub/registry.go ファイルの表示

1
+package hub
2
+
3
+import (
4
+	"context"
5
+	"sync"
6
+
7
+	"github.com/9seconds/mtg/conntypes"
8
+)
9
+
10
+var Registry *RegistryStruct
11
+
12
+type RegistryStruct struct {
13
+	conns map[string]*closeableChannel
14
+	ctx   context.Context
15
+	mutex sync.RWMutex
16
+}
17
+
18
+func (r *RegistryStruct) Register(id conntypes.ConnID) ChannelReadCloser {
19
+	channel := newCloseableChannel(r.ctx)
20
+
21
+	r.mutex.Lock()
22
+	r.conns[string(id[:])] = channel
23
+	r.mutex.Unlock()
24
+
25
+	return channel
26
+}
27
+
28
+func (r *RegistryStruct) Unregister(id conntypes.ConnID) {
29
+	r.mutex.Lock()
30
+	defer r.mutex.Unlock()
31
+
32
+	if channel, ok := r.conns[string(id[:])]; ok {
33
+		channel.Close()
34
+		delete(r.conns, string(id[:]))
35
+	}
36
+}
37
+
38
+func (r *RegistryStruct) getChannel(id conntypes.ConnID) (*closeableChannel, bool) {
39
+	r.mutex.RLock()
40
+	defer r.mutex.RUnlock()
41
+
42
+	if value, ok := r.conns[string(id[:])]; ok {
43
+		return value, true
44
+	}
45
+	return nil, false
46
+}
47
+
48
+func InitRegistry(ctx context.Context) {
49
+	Registry = &RegistryStruct{
50
+		ctx:   ctx,
51
+		conns: map[string]*closeableChannel{},
52
+	}
53
+}

+ 7
- 8
mtproto/protocol.go ファイルの表示

3
 import (
3
 import (
4
 	"fmt"
4
 	"fmt"
5
 
5
 
6
+	"github.com/9seconds/mtg/conntypes"
6
 	"github.com/9seconds/mtg/mtproto/rpc"
7
 	"github.com/9seconds/mtg/mtproto/rpc"
7
 	"github.com/9seconds/mtg/protocol"
8
 	"github.com/9seconds/mtg/protocol"
8
 	"github.com/9seconds/mtg/telegram"
9
 	"github.com/9seconds/mtg/telegram"
9
 	"github.com/9seconds/mtg/wrappers"
10
 	"github.com/9seconds/mtg/wrappers"
10
 )
11
 )
11
 
12
 
12
-func TelegramProtocol(req *protocol.TelegramRequest) (wrappers.Wrap, error) {
13
-	conn, err := telegram.Middle.Dial(req.Ctx,
14
-		req.Cancel,
15
-		req.ClientProtocol.DC(),
13
+func TelegramProtocol(req *protocol.TelegramRequest) (conntypes.PacketReadWriteCloser, error) {
14
+	conn, err := telegram.Middle.Dial(req.ClientProtocol.DC(),
16
 		req.ClientProtocol.ConnectionProtocol())
15
 		req.ClientProtocol.ConnectionProtocol())
17
 	if err != nil {
16
 	if err != nil {
18
 		return nil, fmt.Errorf("cannot connect to telegram: %w", err)
17
 		return nil, fmt.Errorf("cannot connect to telegram: %w", err)
42
 	return frameConn, nil
41
 	return frameConn, nil
43
 }
42
 }
44
 
43
 
45
-func doRPCNonceRequest(conn wrappers.PacketWriter) (*rpc.NonceRequest, error) {
44
+func doRPCNonceRequest(conn conntypes.PacketWriter) (*rpc.NonceRequest, error) {
46
 	rpcNonceReq, err := rpc.NewNonceRequest(telegram.Middle.Secret())
45
 	rpcNonceReq, err := rpc.NewNonceRequest(telegram.Middle.Secret())
47
 	if err != nil {
46
 	if err != nil {
48
 		panic(err)
47
 		panic(err)
54
 	return rpcNonceReq, nil
53
 	return rpcNonceReq, nil
55
 }
54
 }
56
 
55
 
57
-func getRPCNonceResponse(conn wrappers.PacketReader, req *rpc.NonceRequest) (*rpc.NonceResponse, error) {
56
+func getRPCNonceResponse(conn conntypes.PacketReader, req *rpc.NonceRequest) (*rpc.NonceResponse, error) {
58
 	packet, err := conn.Read()
57
 	packet, err := conn.Read()
59
 	if err != nil {
58
 	if err != nil {
60
 		return nil, fmt.Errorf("cannot read from connection: %w", err)
59
 		return nil, fmt.Errorf("cannot read from connection: %w", err)
71
 	return resp, nil
70
 	return resp, nil
72
 }
71
 }
73
 
72
 
74
-func doRPCHandshakeRequest(conn wrappers.PacketWriter) error {
73
+func doRPCHandshakeRequest(conn conntypes.PacketWriter) error {
75
 	if err := conn.Write(rpc.HandshakeRequest); err != nil {
74
 	if err := conn.Write(rpc.HandshakeRequest); err != nil {
76
 		return fmt.Errorf("cannot make a request: %w", err)
75
 		return fmt.Errorf("cannot make a request: %w", err)
77
 	}
76
 	}
78
 	return nil
77
 	return nil
79
 }
78
 }
80
 
79
 
81
-func getRPCHandshakeResponse(conn wrappers.PacketReader) error {
80
+func getRPCHandshakeResponse(conn conntypes.PacketReader) error {
82
 	packet, err := conn.Read()
81
 	packet, err := conn.Read()
83
 	if err != nil {
82
 	if err != nil {
84
 		return fmt.Errorf("cannot read a response: %w", err)
83
 		return fmt.Errorf("cannot read a response: %w", err)

+ 3
- 3
obfuscated2/client_protocol.go ファイルの表示

37
 	return c.dc
37
 	return c.dc
38
 }
38
 }
39
 
39
 
40
-func (c *ClientProtocol) Handshake(socket wrappers.StreamReadWriteCloser) (wrappers.StreamReadWriteCloser, error) {
40
+func (c *ClientProtocol) Handshake(socket conntypes.StreamReadWriteCloser) (conntypes.StreamReadWriteCloser, error) {
41
 	fm, err := c.ReadFrame(socket)
41
 	fm, err := c.ReadFrame(socket)
42
 	if err != nil {
42
 	if err != nil {
43
 		return nil, fmt.Errorf("cannot make a client handshake: %w", err)
43
 		return nil, fmt.Errorf("cannot make a client handshake: %w", err)
88
 	return wrappers.NewObfuscated2(socket, encryptor, decryptor), nil
88
 	return wrappers.NewObfuscated2(socket, encryptor, decryptor), nil
89
 }
89
 }
90
 
90
 
91
-func (c *ClientProtocol) ReadFrame(socket wrappers.StreamReader) (fm Frame, err error) {
91
+func (c *ClientProtocol) ReadFrame(socket conntypes.StreamReader) (fm Frame, err error) {
92
 	if _, err = io.ReadFull(handshakeReader{socket}, fm.Bytes()); err != nil {
92
 	if _, err = io.ReadFull(handshakeReader{socket}, fm.Bytes()); err != nil {
93
 		err = fmt.Errorf("cannot extract obfuscated2 frame: %w", err)
93
 		err = fmt.Errorf("cannot extract obfuscated2 frame: %w", err)
94
 	}
94
 	}
96
 }
96
 }
97
 
97
 
98
 type handshakeReader struct {
98
 type handshakeReader struct {
99
-	parent wrappers.StreamReader
99
+	parent conntypes.StreamReader
100
 }
100
 }
101
 
101
 
102
 func (h handshakeReader) Read(p []byte) (int, error) {
102
 func (h handshakeReader) Read(p []byte) (int, error) {

+ 7
- 6
obfuscated2/telegram_protocol.go ファイルの表示

4
 	"crypto/rand"
4
 	"crypto/rand"
5
 	"fmt"
5
 	"fmt"
6
 
6
 
7
+	"github.com/9seconds/mtg/conntypes"
7
 	"github.com/9seconds/mtg/protocol"
8
 	"github.com/9seconds/mtg/protocol"
8
 	"github.com/9seconds/mtg/telegram"
9
 	"github.com/9seconds/mtg/telegram"
9
 	"github.com/9seconds/mtg/utils"
10
 	"github.com/9seconds/mtg/utils"
10
 	"github.com/9seconds/mtg/wrappers"
11
 	"github.com/9seconds/mtg/wrappers"
11
 )
12
 )
12
 
13
 
13
-func TelegramProtocol(req *protocol.TelegramRequest) (wrappers.Wrap, error) {
14
-	socket, err := telegram.Direct.Dial(req.Ctx,
15
-		req.Cancel,
16
-		req.ClientProtocol.DC(),
14
+func TelegramProtocol(req *protocol.TelegramRequest) (conntypes.StreamReadWriteCloser, error) {
15
+	conn, err := telegram.Direct.Dial(req.ClientProtocol.DC(),
17
 		req.ClientProtocol.ConnectionProtocol())
16
 		req.ClientProtocol.ConnectionProtocol())
18
 	if err != nil {
17
 	if err != nil {
19
 		return nil, fmt.Errorf("cannot dial to telegram: %w", err)
18
 		return nil, fmt.Errorf("cannot dial to telegram: %w", err)
20
 	}
19
 	}
20
+	conn = wrappers.NewTimeout(conn)
21
+	conn = wrappers.NewCtx(req.Ctx, req.Cancel, conn)
21
 	fm := generateFrame(req.ClientProtocol)
22
 	fm := generateFrame(req.ClientProtocol)
22
 	data := fm.Bytes()
23
 	data := fm.Bytes()
23
 
24
 
30
 	encryptor.XORKeyStream(data, data)
31
 	encryptor.XORKeyStream(data, data)
31
 	copy(data[:frameOffsetIV], copyFrame[:frameOffsetIV])
32
 	copy(data[:frameOffsetIV], copyFrame[:frameOffsetIV])
32
 
33
 
33
-	if _, err := socket.Write(data); err != nil {
34
+	if _, err := conn.Write(data); err != nil {
34
 		return nil, fmt.Errorf("cannot write handshake frame to telegram: %w", err)
35
 		return nil, fmt.Errorf("cannot write handshake frame to telegram: %w", err)
35
 	}
36
 	}
36
 
37
 
37
-	return wrappers.NewObfuscated2(socket, encryptor, decryptor), nil
38
+	return wrappers.NewObfuscated2(conn, encryptor, decryptor), nil
38
 }
39
 }
39
 
40
 
40
 func generateFrame(cp protocol.ClientProtocol) (fm Frame) {
41
 func generateFrame(cp protocol.ClientProtocol) (fm Frame) {

+ 2
- 6
protocol/interfaces.go ファイルの表示

1
 package protocol
1
 package protocol
2
 
2
 
3
-import (
4
-	"github.com/9seconds/mtg/conntypes"
5
-	"github.com/9seconds/mtg/wrappers"
6
-)
3
+import "github.com/9seconds/mtg/conntypes"
7
 
4
 
8
 type ClientProtocol interface {
5
 type ClientProtocol interface {
9
-	Handshake(wrappers.StreamReadWriteCloser) (wrappers.StreamReadWriteCloser, error)
6
+	Handshake(conntypes.StreamReadWriteCloser) (conntypes.StreamReadWriteCloser, error)
10
 	ConnectionType() conntypes.ConnectionType
7
 	ConnectionType() conntypes.ConnectionType
11
 	ConnectionProtocol() conntypes.ConnectionProtocol
8
 	ConnectionProtocol() conntypes.ConnectionProtocol
12
 	DC() conntypes.DC
9
 	DC() conntypes.DC
13
 }
10
 }
14
 
11
 
15
-type TelegramProtocol func(*TelegramRequest) (wrappers.Wrap, error)
16
 type ClientProtocolMaker func() ClientProtocol
12
 type ClientProtocolMaker func() ClientProtocol

+ 1
- 2
protocol/request.go ファイルの表示

6
 	"go.uber.org/zap"
6
 	"go.uber.org/zap"
7
 
7
 
8
 	"github.com/9seconds/mtg/conntypes"
8
 	"github.com/9seconds/mtg/conntypes"
9
-	"github.com/9seconds/mtg/wrappers"
10
 )
9
 )
11
 
10
 
12
 type TelegramRequest struct {
11
 type TelegramRequest struct {
13
 	Logger         *zap.SugaredLogger
12
 	Logger         *zap.SugaredLogger
14
-	ClientConn     wrappers.StreamReadWriteCloser
13
+	ClientConn     conntypes.StreamReadWriteCloser
15
 	ConnID         conntypes.ConnID
14
 	ConnID         conntypes.ConnID
16
 	Ctx            context.Context
15
 	Ctx            context.Context
17
 	Cancel         context.CancelFunc
16
 	Cancel         context.CancelFunc

+ 10
- 9
proxy/proxy.go ファイルの表示

63
 	ctx, cancel := context.WithCancel(p.Context)
63
 	ctx, cancel := context.WithCancel(p.Context)
64
 	defer cancel()
64
 	defer cancel()
65
 
65
 
66
-	wrappedConn := wrappers.NewClientConn(ctx, cancel, conn, connID)
67
-	wrappedConn = wrappers.NewTraffic(wrappedConn)
68
-	defer wrappedConn.Close()
66
+	clientConn := wrappers.NewClientConn(conn, connID)
67
+	clientConn = wrappers.NewCtx(ctx, cancel, clientConn)
68
+	clientConn = wrappers.NewTimeout(clientConn)
69
+	clientConn = wrappers.NewTraffic(clientConn)
70
+	defer clientConn.Close()
69
 
71
 
70
 	clientProtocol := p.ClientProtocolMaker()
72
 	clientProtocol := p.ClientProtocolMaker()
71
-	wrappedConn, err := clientProtocol.Handshake(wrappedConn)
73
+	clientConn, err := clientProtocol.Handshake(clientConn)
72
 	if err != nil {
74
 	if err != nil {
73
 		logger.Warnw("Cannot perform client handshake", "error", err)
75
 		logger.Warnw("Cannot perform client handshake", "error", err)
74
 		return
76
 		return
75
 	}
77
 	}
76
-	defer wrappedConn.Close()
77
 
78
 
78
-	stats.S.ClientConnected(clientProtocol.ConnectionType(), wrappedConn.RemoteAddr())
79
-	defer stats.S.ClientDisconnected(clientProtocol.ConnectionType(), wrappedConn.RemoteAddr())
79
+	stats.S.ClientConnected(clientProtocol.ConnectionType(), clientConn.RemoteAddr())
80
+	defer stats.S.ClientDisconnected(clientProtocol.ConnectionType(), clientConn.RemoteAddr())
80
 	logger.Infow("Client connected", "addr", conn.RemoteAddr())
81
 	logger.Infow("Client connected", "addr", conn.RemoteAddr())
81
 
82
 
82
 	req := &protocol.TelegramRequest{
83
 	req := &protocol.TelegramRequest{
83
 		Logger:         logger,
84
 		Logger:         logger,
84
-		ClientConn:     wrappedConn,
85
+		ClientConn:     clientConn,
85
 		ConnID:         connID,
86
 		ConnID:         connID,
86
 		Ctx:            ctx,
87
 		Ctx:            ctx,
87
 		Cancel:         cancel,
88
 		Cancel:         cancel,
102
 	if err != nil {
103
 	if err != nil {
103
 		return err
104
 		return err
104
 	}
105
 	}
105
-	telegramConn := telegramConnRaw.(wrappers.StreamReadWriteCloser)
106
+	telegramConn := telegramConnRaw.(conntypes.StreamReadWriteCloser)
106
 	defer telegramConn.Close()
107
 	defer telegramConn.Close()
107
 
108
 
108
 	wg := &sync.WaitGroup{}
109
 	wg := &sync.WaitGroup{}

+ 16
- 22
telegram/base.go ファイルの表示

1
 package telegram
1
 package telegram
2
 
2
 
3
 import (
3
 import (
4
-	"context"
5
 	"fmt"
4
 	"fmt"
6
 	"math/rand"
5
 	"math/rand"
7
 	"net"
6
 	"net"
28
 	return b.secret
27
 	return b.secret
29
 }
28
 }
30
 
29
 
31
-func (b *baseTelegram) dialToAddress(ctx context.Context,
32
-	cancel context.CancelFunc,
33
-	addr string) (wrappers.StreamReadWriteCloser, error) {
34
-	conn, err := b.dialer.Dial("tcp", addr)
35
-	if err != nil {
36
-		return nil, fmt.Errorf("dial has failed: %w", err)
37
-	}
38
-
39
-	if err := utils.InitTCP(conn); err != nil {
40
-		return nil, fmt.Errorf("cannot initialize tcp socket: %w", err)
41
-	}
42
-
43
-	return wrappers.NewTelegramConn(ctx, cancel, conn), nil
44
-}
45
-
46
-func (b *baseTelegram) dial(ctx context.Context,
47
-	cancel context.CancelFunc,
48
-	dc conntypes.DC,
49
-	protocol conntypes.ConnectionProtocol) (wrappers.StreamReadWriteCloser, error) {
30
+func (b *baseTelegram) dial(dc conntypes.DC,
31
+	protocol conntypes.ConnectionProtocol) (conntypes.StreamReadWriteCloser, error) {
50
 	addr := ""
32
 	addr := ""
51
 
33
 
52
 	switch protocol {
34
 	switch protocol {
56
 		addr = b.chooseAddress(b.v6Addresses, dc, b.V6DefaultDC)
38
 		addr = b.chooseAddress(b.v6Addresses, dc, b.V6DefaultDC)
57
 	}
39
 	}
58
 
40
 
59
-	return b.dialToAddress(ctx, cancel, addr)
41
+	conn, err := b.dialer.Dial("tcp", addr)
42
+	if err != nil {
43
+		return nil, fmt.Errorf("dial has failed: %w", err)
44
+	}
45
+
46
+	if err := utils.InitTCP(conn); err != nil {
47
+		return nil, fmt.Errorf("cannot initialize tcp socket: %w", err)
48
+	}
49
+
50
+	return wrappers.NewTelegramConn(conn), nil
60
 }
51
 }
61
 
52
 
62
 func (b *baseTelegram) chooseAddress(addresses map[conntypes.DC][]string,
53
 func (b *baseTelegram) chooseAddress(addresses map[conntypes.DC][]string,
66
 		addrs, _ = addresses[defaultDC]
57
 		addrs, _ = addresses[defaultDC]
67
 	}
58
 	}
68
 
59
 
69
-	if len(addrs) > 0 {
60
+	switch {
61
+	case len(addrs) == 1:
62
+		return addrs[0]
63
+	case len(addrs) > 1:
70
 		return addrs[rand.Intn(len(addrs))]
64
 		return addrs[rand.Intn(len(addrs))]
71
 	}
65
 	}
72
 
66
 

+ 11
- 19
telegram/direct.go ファイルの表示

1
 package telegram
1
 package telegram
2
 
2
 
3
 import (
3
 import (
4
-	"context"
5
 	"net"
4
 	"net"
6
 
5
 
7
 	"github.com/9seconds/mtg/conntypes"
6
 	"github.com/9seconds/mtg/conntypes"
8
-	"github.com/9seconds/mtg/wrappers"
9
 )
7
 )
10
 
8
 
11
-var Direct = newDirectTelegram()
12
-
13
 const (
9
 const (
14
 	directV4DefaultIdx conntypes.DC = 1
10
 	directV4DefaultIdx conntypes.DC = 1
15
 	directV6DefaultIdx conntypes.DC = 1
11
 	directV6DefaultIdx conntypes.DC = 1
36
 	baseTelegram
32
 	baseTelegram
37
 }
33
 }
38
 
34
 
39
-func (d *directTelegram) Dial(ctx context.Context,
40
-	cancel context.CancelFunc,
41
-	dc conntypes.DC,
42
-	protocol conntypes.ConnectionProtocol) (wrappers.StreamReadWriteCloser, error) {
35
+func (d *directTelegram) Dial(dc conntypes.DC,
36
+	protocol conntypes.ConnectionProtocol) (conntypes.StreamReadWriteCloser, error) {
43
 	switch {
37
 	switch {
44
 	case dc < 0:
38
 	case dc < 0:
45
 		dc = -dc
39
 		dc = -dc
47
 		dc = conntypes.DCDefaultIdx
41
 		dc = conntypes.DCDefaultIdx
48
 	}
42
 	}
49
 
43
 
50
-	return d.baseTelegram.dial(ctx, cancel, dc-1, protocol)
44
+	return d.baseTelegram.dial(dc-1, protocol)
51
 }
45
 }
52
 
46
 
53
-func newDirectTelegram() Telegram {
54
-	return &directTelegram{
55
-		baseTelegram: baseTelegram{
56
-			dialer:      net.Dialer{Timeout: telegramDialTimeout},
57
-			v4DefaultDC: directV4DefaultIdx,
58
-			V6DefaultDC: directV6DefaultIdx,
59
-			v4Addresses: directV4Addresses,
60
-			v6Addresses: directV6Addresses,
61
-		},
62
-	}
47
+var Direct = &directTelegram{
48
+	baseTelegram: baseTelegram{
49
+		dialer:      net.Dialer{Timeout: telegramDialTimeout},
50
+		v4DefaultDC: directV4DefaultIdx,
51
+		V6DefaultDC: directV6DefaultIdx,
52
+		v4Addresses: directV4Addresses,
53
+		v6Addresses: directV6Addresses,
54
+	},
63
 }
55
 }

+ 2
- 10
telegram/interfaces.go ファイルの表示

1
 package telegram
1
 package telegram
2
 
2
 
3
-import (
4
-	"context"
5
-
6
-	"github.com/9seconds/mtg/conntypes"
7
-	"github.com/9seconds/mtg/wrappers"
8
-)
3
+import  "github.com/9seconds/mtg/conntypes"
9
 
4
 
10
 type Telegram interface {
5
 type Telegram interface {
11
-	Dial(context.Context,
12
-		context.CancelFunc,
13
-		conntypes.DC,
14
-		conntypes.ConnectionProtocol) (wrappers.StreamReadWriteCloser, error)
6
+	Dial(conntypes.DC, conntypes.ConnectionProtocol) (conntypes.StreamReadWriteCloser, error)
15
 	Secret() []byte
7
 	Secret() []byte
16
 }
8
 }

+ 3
- 7
telegram/middle.go ファイルの表示

1
 package telegram
1
 package telegram
2
 
2
 
3
 import (
3
 import (
4
-	"context"
5
 	"fmt"
4
 	"fmt"
6
 	"net"
5
 	"net"
7
 	"sync"
6
 	"sync"
11
 
10
 
12
 	"github.com/9seconds/mtg/conntypes"
11
 	"github.com/9seconds/mtg/conntypes"
13
 	"github.com/9seconds/mtg/telegram/api"
12
 	"github.com/9seconds/mtg/telegram/api"
14
-	"github.com/9seconds/mtg/wrappers"
15
 )
13
 )
16
 
14
 
17
 const middleTelegramBackgroundUpdateEvery = time.Hour
15
 const middleTelegramBackgroundUpdateEvery = time.Hour
67
 	}
65
 	}
68
 }
66
 }
69
 
67
 
70
-func (m *middleTelegram) Dial(ctx context.Context,
71
-	cancel context.CancelFunc,
72
-	dc conntypes.DC,
73
-	protocol conntypes.ConnectionProtocol) (wrappers.StreamReadWriteCloser, error) {
68
+func (m *middleTelegram) Dial(dc conntypes.DC,
69
+	protocol conntypes.ConnectionProtocol) (conntypes.StreamReadWriteCloser, error) {
74
 	if dc == 0 {
70
 	if dc == 0 {
75
 		dc = conntypes.DCDefaultIdx
71
 		dc = conntypes.DCDefaultIdx
76
 	}
72
 	}
78
 	m.mutex.RLock()
74
 	m.mutex.RLock()
79
 	defer m.mutex.RUnlock()
75
 	defer m.mutex.RUnlock()
80
 
76
 
81
-	return m.baseTelegram.dial(ctx, cancel, dc, protocol)
77
+	return m.baseTelegram.dial(dc, protocol)
82
 }
78
 }
83
 
79
 
84
 func MiddleInit() {
80
 func MiddleInit() {

+ 10
- 6
wrappers/blockcipher.go ファイルの表示

10
 	"time"
10
 	"time"
11
 
11
 
12
 	"go.uber.org/zap"
12
 	"go.uber.org/zap"
13
+
14
+	"github.com/9seconds/mtg/conntypes"
13
 )
15
 )
14
 
16
 
15
 const blockCipherReadCurrentDataBufferSize = 1024 + 1 // +1 because telegram operates with blocks mod 4
17
 const blockCipherReadCurrentDataBufferSize = 1024 + 1 // +1 because telegram operates with blocks mod 4
17
 type wrapperBlockCipher struct {
19
 type wrapperBlockCipher struct {
18
 	buf bytes.Buffer
20
 	buf bytes.Buffer
19
 
21
 
20
-	parent    StreamReadWriteCloser
22
+	parent    conntypes.StreamReadWriteCloser
21
 	encryptor cipher.BlockMode
23
 	encryptor cipher.BlockMode
22
 	decryptor cipher.BlockMode
24
 	decryptor cipher.BlockMode
23
 }
25
 }
47
 	return w.read(p, readAllTimeout(timeout))
49
 	return w.read(p, readAllTimeout(timeout))
48
 }
50
 }
49
 
51
 
50
-func (w *wrapperBlockCipher) read(p []byte, reader func(StreamReadWriteCloser) ([]byte, error)) (int, error) {
52
+func (w *wrapperBlockCipher) read(p []byte,
53
+	reader func(conntypes.StreamReadWriteCloser) ([]byte, error)) (int, error) {
51
 	if w.buf.Len() > 0 {
54
 	if w.buf.Len() > 0 {
52
 		return w.flush(p)
55
 		return w.flush(p)
53
 	}
56
 	}
90
 	return encrypted, nil
93
 	return encrypted, nil
91
 }
94
 }
92
 
95
 
93
-func readAll(src StreamReadWriteCloser) (rv []byte, err error) {
96
+func readAll(src conntypes.StreamReadWriteCloser) (rv []byte, err error) {
94
 	buf := make([]byte, blockCipherReadCurrentDataBufferSize)
97
 	buf := make([]byte, blockCipherReadCurrentDataBufferSize)
95
 	n := blockCipherReadCurrentDataBufferSize
98
 	n := blockCipherReadCurrentDataBufferSize
96
 
99
 
105
 	return rv, nil
108
 	return rv, nil
106
 }
109
 }
107
 
110
 
108
-func readAllTimeout(timeout time.Duration) func(StreamReadWriteCloser) ([]byte, error) {
109
-	return func(src StreamReadWriteCloser) (rv []byte, err error) {
111
+func readAllTimeout(timeout time.Duration) func(conntypes.StreamReadWriteCloser) ([]byte, error) {
112
+	return func(src conntypes.StreamReadWriteCloser) (rv []byte, err error) {
110
 		tmo := timeout
113
 		tmo := timeout
111
 		buf := make([]byte, blockCipherReadCurrentDataBufferSize)
114
 		buf := make([]byte, blockCipherReadCurrentDataBufferSize)
112
 		n := blockCipherReadCurrentDataBufferSize
115
 		n := blockCipherReadCurrentDataBufferSize
148
 	return w.parent.RemoteAddr()
151
 	return w.parent.RemoteAddr()
149
 }
152
 }
150
 
153
 
151
-func newBlockCipher(parent StreamReadWriteCloser, encryptor, decryptor cipher.BlockMode) StreamReadWriteCloser {
154
+func newBlockCipher(parent conntypes.StreamReadWriteCloser,
155
+	encryptor, decryptor cipher.BlockMode) conntypes.StreamReadWriteCloser {
152
 	return &wrapperBlockCipher{
156
 	return &wrapperBlockCipher{
153
 		parent:    parent,
157
 		parent:    parent,
154
 		encryptor: encryptor,
158
 		encryptor: encryptor,

+ 29
- 61
wrappers/conn.go ファイルの表示

1
 package wrappers
1
 package wrappers
2
 
2
 
3
 import (
3
 import (
4
-	"context"
5
 	"fmt"
4
 	"fmt"
6
 	"net"
5
 	"net"
7
 	"time"
6
 	"time"
19
 	connPurposeTelegram
18
 	connPurposeTelegram
20
 )
19
 )
21
 
20
 
22
-const (
23
-	connTimeoutRead  = 2 * time.Minute
24
-	connTimeoutWrite = 2 * time.Minute
25
-)
26
-
27
 type wrapperConn struct {
21
 type wrapperConn struct {
28
 	parent     net.Conn
22
 	parent     net.Conn
29
-	ctx        context.Context
30
-	cancel     context.CancelFunc
31
 	connID     conntypes.ConnID
23
 	connID     conntypes.ConnID
32
 	logger     *zap.SugaredLogger
24
 	logger     *zap.SugaredLogger
33
 	localAddr  *net.TCPAddr
25
 	localAddr  *net.TCPAddr
35
 }
27
 }
36
 
28
 
37
 func (w *wrapperConn) WriteTimeout(p []byte, timeout time.Duration) (int, error) {
29
 func (w *wrapperConn) WriteTimeout(p []byte, timeout time.Duration) (int, error) {
38
-	select {
39
-	case <-w.ctx.Done():
30
+	if err := w.parent.SetWriteDeadline(time.Now().Add(timeout)); err != nil {
40
 		w.Close()
31
 		w.Close()
41
-		return 0, fmt.Errorf("cannot write because context was closed: %w", w.ctx.Err())
42
-
43
-	default:
44
-		if err := w.parent.SetWriteDeadline(time.Now().Add(timeout)); err != nil {
45
-			w.Close() // nolint: gosec
46
-			return 0, fmt.Errorf("cannot set write deadline to the socket: %w", err)
47
-		}
48
-
49
-		n, err := w.parent.Write(p)
50
-		w.logger.Debugw("Write to stream", "bytes", n, "error", err)
51
-		if err != nil {
52
-			w.Close() // nolint: gosec
53
-		}
54
-
55
-		return n, err
32
+		return 0, fmt.Errorf("cannot set write deadline to the socket: %w", err)
56
 	}
33
 	}
34
+
35
+	return w.Write(p)
57
 }
36
 }
58
 
37
 
59
 func (w *wrapperConn) Write(p []byte) (int, error) {
38
 func (w *wrapperConn) Write(p []byte) (int, error) {
60
-	return w.WriteTimeout(p, connTimeoutWrite)
39
+	n, err := w.parent.Write(p)
40
+	w.logger.Debugw("write to stream", "bytes", n, "error", err)
41
+	if err != nil {
42
+		w.Close() // nolint: gosec
43
+	}
44
+
45
+	return n, err
61
 }
46
 }
62
 
47
 
63
 func (w *wrapperConn) ReadTimeout(p []byte, timeout time.Duration) (int, error) {
48
 func (w *wrapperConn) ReadTimeout(p []byte, timeout time.Duration) (int, error) {
64
-	select {
65
-	case <-w.ctx.Done():
49
+	if err := w.parent.SetReadDeadline(time.Now().Add(timeout)); err != nil {
66
 		w.Close()
50
 		w.Close()
67
-		return 0, fmt.Errorf("cannot read because context was closed: %w", w.ctx.Err())
68
-
69
-	default:
70
-		if err := w.parent.SetReadDeadline(time.Now().Add(timeout)); err != nil {
71
-			w.Close()
72
-			return 0, fmt.Errorf("cannot set read deadline to the socket: %w", err)
73
-		}
74
-
75
-		n, err := w.parent.Read(p)
76
-		w.logger.Debugw("Read from stream", "bytes", n, "error", err)
77
-		if err != nil {
78
-			w.Close()
79
-		}
80
-
81
-		return n, err
51
+		return 0, fmt.Errorf("cannot set read deadline to the socket: %w", err)
82
 	}
52
 	}
53
+
54
+	return w.Read(p)
83
 }
55
 }
84
 
56
 
85
 func (w *wrapperConn) Read(p []byte) (int, error) {
57
 func (w *wrapperConn) Read(p []byte) (int, error) {
86
-	return w.ReadTimeout(p, connTimeoutRead)
58
+	n, err := w.parent.Read(p)
59
+	w.logger.Debugw("Read from stream", "bytes", n, "error", err)
60
+	if err != nil {
61
+		w.Close()
62
+	}
63
+
64
+	return n, err
87
 }
65
 }
88
 
66
 
89
 func (w *wrapperConn) Close() error {
67
 func (w *wrapperConn) Close() error {
90
 	w.logger.Debugw("Close connection")
68
 	w.logger.Debugw("Close connection")
91
-	w.cancel()
92
-
93
 	return w.parent.Close()
69
 	return w.parent.Close()
94
 }
70
 }
95
 
71
 
109
 	return w.remoteAddr
85
 	return w.remoteAddr
110
 }
86
 }
111
 
87
 
112
-func newConn(ctx context.Context,
113
-	cancel context.CancelFunc,
114
-	parent net.Conn,
88
+func newConn(parent net.Conn,
115
 	connID conntypes.ConnID,
89
 	connID conntypes.ConnID,
116
-	purpose connPurpose) StreamReadWriteCloser {
90
+	purpose connPurpose) conntypes.StreamReadWriteCloser {
117
 	localAddr := *parent.LocalAddr().(*net.TCPAddr)
91
 	localAddr := *parent.LocalAddr().(*net.TCPAddr)
118
 
92
 
119
 	if parent.RemoteAddr().(*net.TCPAddr).IP.To4() != nil {
93
 	if parent.RemoteAddr().(*net.TCPAddr).IP.To4() != nil {
135
 
109
 
136
 	return &wrapperConn{
110
 	return &wrapperConn{
137
 		parent:     parent,
111
 		parent:     parent,
138
-		ctx:        ctx,
139
-		cancel:     cancel,
140
 		connID:     connID,
112
 		connID:     connID,
141
 		logger:     logger,
113
 		logger:     logger,
142
 		remoteAddr: parent.RemoteAddr().(*net.TCPAddr),
114
 		remoteAddr: parent.RemoteAddr().(*net.TCPAddr),
144
 	}
116
 	}
145
 }
117
 }
146
 
118
 
147
-func NewClientConn(ctx context.Context,
148
-	cancel context.CancelFunc,
149
-	parent net.Conn,
150
-	connID conntypes.ConnID) StreamReadWriteCloser {
151
-	return newConn(ctx, cancel, parent, connID, connPurposeClient)
119
+func NewClientConn(parent net.Conn,
120
+	connID conntypes.ConnID) conntypes.StreamReadWriteCloser {
121
+	return newConn(parent, connID, connPurposeClient)
152
 }
122
 }
153
 
123
 
154
-func NewTelegramConn(ctx context.Context,
155
-	cancel context.CancelFunc,
156
-	parent net.Conn) StreamReadWriteCloser {
157
-	return newConn(ctx, cancel, parent, conntypes.ConnID{}, connPurposeTelegram)
124
+func NewTelegramConn(parent net.Conn) conntypes.StreamReadWriteCloser {
125
+	return newConn(parent, conntypes.ConnID{}, connPurposeTelegram)
158
 }
126
 }

+ 89
- 0
wrappers/ctx.go ファイルの表示

1
+package wrappers
2
+
3
+import (
4
+	"context"
5
+	"fmt"
6
+	"net"
7
+	"time"
8
+
9
+	"go.uber.org/zap"
10
+
11
+	"github.com/9seconds/mtg/conntypes"
12
+)
13
+
14
+type wrapperCtx struct {
15
+	parent conntypes.StreamReadWriteCloser
16
+	ctx    context.Context
17
+	cancel context.CancelFunc
18
+}
19
+
20
+func (w *wrapperCtx) WriteTimeout(p []byte, timeout time.Duration) (int, error) {
21
+	select {
22
+	case <-w.ctx.Done():
23
+		w.Close()
24
+		return 0, fmt.Errorf("cannot write because context was closed: %w", w.ctx.Err())
25
+	default:
26
+		return w.parent.WriteTimeout(p, timeout)
27
+	}
28
+}
29
+
30
+func (w *wrapperCtx) Write(p []byte) (int, error) {
31
+	select {
32
+	case <-w.ctx.Done():
33
+		w.Close()
34
+		return 0, fmt.Errorf("cannot write because context was closed: %w", w.ctx.Err())
35
+	default:
36
+		return w.parent.Write(p)
37
+	}
38
+}
39
+
40
+func (w *wrapperCtx) ReadTimeout(p []byte, timeout time.Duration) (int, error) {
41
+	select {
42
+	case <-w.ctx.Done():
43
+		w.Close()
44
+		return 0, fmt.Errorf("cannot write because context was closed: %w", w.ctx.Err())
45
+	default:
46
+		return w.parent.ReadTimeout(p, timeout)
47
+	}
48
+}
49
+
50
+func (w *wrapperCtx) Read(p []byte) (int, error) {
51
+	select {
52
+	case <-w.ctx.Done():
53
+		w.Close()
54
+		return 0, fmt.Errorf("cannot write because context was closed: %w", w.ctx.Err())
55
+	default:
56
+		return w.parent.Read(p)
57
+	}
58
+}
59
+
60
+func (w *wrapperCtx) Close() error {
61
+	w.cancel()
62
+	return w.parent.Close()
63
+}
64
+
65
+func (w *wrapperCtx) Conn() net.Conn {
66
+	return w.parent.Conn()
67
+}
68
+
69
+func (w *wrapperCtx) Logger() *zap.SugaredLogger {
70
+	return w.parent.Logger().Named("ctx")
71
+}
72
+
73
+func (w *wrapperCtx) LocalAddr() *net.TCPAddr {
74
+	return w.parent.LocalAddr()
75
+}
76
+
77
+func (w *wrapperCtx) RemoteAddr() *net.TCPAddr {
78
+	return w.parent.RemoteAddr()
79
+}
80
+
81
+func NewCtx(ctx context.Context,
82
+	cancel context.CancelFunc,
83
+	parent conntypes.StreamReadWriteCloser) conntypes.StreamReadWriteCloser {
84
+	return &wrapperCtx{
85
+		parent: parent,
86
+		ctx:    ctx,
87
+		cancel: cancel,
88
+	}
89
+}

+ 3
- 2
wrappers/mtproto_cipher.go ファイルの表示

9
 	"encoding/binary"
9
 	"encoding/binary"
10
 	"net"
10
 	"net"
11
 
11
 
12
+	"github.com/9seconds/mtg/conntypes"
12
 	"github.com/9seconds/mtg/mtproto/rpc"
13
 	"github.com/9seconds/mtg/mtproto/rpc"
13
 	"github.com/9seconds/mtg/utils"
14
 	"github.com/9seconds/mtg/utils"
14
 )
15
 )
22
 
23
 
23
 var mtprotoEmptyIP = [4]byte{0x00, 0x00, 0x00, 0x00}
24
 var mtprotoEmptyIP = [4]byte{0x00, 0x00, 0x00, 0x00}
24
 
25
 
25
-func NewMiddleProxyCipher(parent StreamReadWriteCloser,
26
+func NewMiddleProxyCipher(parent conntypes.StreamReadWriteCloser,
26
 	req *rpc.NonceRequest,
27
 	req *rpc.NonceRequest,
27
 	resp *rpc.NonceResponse,
28
 	resp *rpc.NonceResponse,
28
-	secret []byte) StreamReadWriteCloser {
29
+	secret []byte) conntypes.StreamReadWriteCloser {
29
 	localAddr := parent.LocalAddr()
30
 	localAddr := parent.LocalAddr()
30
 	remoteAddr := parent.RemoteAddr()
31
 	remoteAddr := parent.RemoteAddr()
31
 
32
 

+ 6
- 4
wrappers/mtproto_frame.go ファイルの表示

11
 	"net"
11
 	"net"
12
 
12
 
13
 	"go.uber.org/zap"
13
 	"go.uber.org/zap"
14
+
15
+	"github.com/9seconds/mtg/conntypes"
14
 )
16
 )
15
 
17
 
16
 const (
18
 const (
33
 // PADDING is custom padding schema to complete frame length to such that
35
 // PADDING is custom padding schema to complete frame length to such that
34
 //    len(frame) % 16 == 0
36
 //    len(frame) % 16 == 0
35
 type wrapperMtprotoFrame struct {
37
 type wrapperMtprotoFrame struct {
36
-	parent     StreamReadWriteCloser
38
+	parent     conntypes.StreamReadWriteCloser
37
 	logger     *zap.SugaredLogger
39
 	logger     *zap.SugaredLogger
38
 	readSeqNo  int32
40
 	readSeqNo  int32
39
 	writeSeqNo int32
41
 	writeSeqNo int32
40
 }
42
 }
41
 
43
 
42
-func (w *wrapperMtprotoFrame) Read() (Packet, error) {
44
+func (w *wrapperMtprotoFrame) Read() (conntypes.Packet, error) {
43
 	buf := &bytes.Buffer{}
45
 	buf := &bytes.Buffer{}
44
 	sum := crc32.NewIEEE()
46
 	sum := crc32.NewIEEE()
45
 	writer := io.MultiWriter(buf, sum)
47
 	writer := io.MultiWriter(buf, sum)
101
 	return data, nil
103
 	return data, nil
102
 }
104
 }
103
 
105
 
104
-func (w *wrapperMtprotoFrame) Write(p Packet) error {
106
+func (w *wrapperMtprotoFrame) Write(p conntypes.Packet) error {
105
 	messageLength := 4 + 4 + len(p) + 4
107
 	messageLength := 4 + 4 + len(p) + 4
106
 	paddingLength := (aes.BlockSize - messageLength%aes.BlockSize) % aes.BlockSize
108
 	paddingLength := (aes.BlockSize - messageLength%aes.BlockSize) % aes.BlockSize
107
 
109
 
149
 	return w.parent.RemoteAddr()
151
 	return w.parent.RemoteAddr()
150
 }
152
 }
151
 
153
 
152
-func NewMtprotoFrame(parent StreamReadWriteCloser, seqNo int32) PacketReadWriteCloser {
154
+func NewMtprotoFrame(parent conntypes.StreamReadWriteCloser, seqNo int32) conntypes.PacketReadWriteCloser {
153
 	return &wrapperMtprotoFrame{
155
 	return &wrapperMtprotoFrame{
154
 		parent:     parent,
156
 		parent:     parent,
155
 		logger:     parent.Logger().Named("mtproto-frame"),
157
 		logger:     parent.Logger().Named("mtproto-frame"),

+ 5
- 2
wrappers/obfuscated2.go ファイルの表示

7
 	"time"
7
 	"time"
8
 
8
 
9
 	"go.uber.org/zap"
9
 	"go.uber.org/zap"
10
+
11
+	"github.com/9seconds/mtg/conntypes"
10
 )
12
 )
11
 
13
 
12
 type wrapperObfuscated2 struct {
14
 type wrapperObfuscated2 struct {
13
 	encryptor cipher.Stream
15
 	encryptor cipher.Stream
14
 	decryptor cipher.Stream
16
 	decryptor cipher.Stream
15
-	parent    StreamReadWriteCloser
17
+	parent    conntypes.StreamReadWriteCloser
16
 }
18
 }
17
 
19
 
18
 func (w *wrapperObfuscated2) ReadTimeout(p []byte, timeout time.Duration) (int, error) {
20
 func (w *wrapperObfuscated2) ReadTimeout(p []byte, timeout time.Duration) (int, error) {
71
 	return w.parent.Close()
73
 	return w.parent.Close()
72
 }
74
 }
73
 
75
 
74
-func NewObfuscated2(socket StreamReadWriteCloser, encryptor, decryptor cipher.Stream) StreamReadWriteCloser {
76
+func NewObfuscated2(socket conntypes.StreamReadWriteCloser,
77
+	encryptor, decryptor cipher.Stream) conntypes.StreamReadWriteCloser {
75
 	return &wrapperObfuscated2{
78
 	return &wrapperObfuscated2{
76
 		parent:    socket,
79
 		parent:    socket,
77
 		encryptor: encryptor,
80
 		encryptor: encryptor,

+ 3
- 2
wrappers/stats.go ファイルの表示

6
 
6
 
7
 	"go.uber.org/zap"
7
 	"go.uber.org/zap"
8
 
8
 
9
+	"github.com/9seconds/mtg/conntypes"
9
 	"github.com/9seconds/mtg/stats"
10
 	"github.com/9seconds/mtg/stats"
10
 )
11
 )
11
 
12
 
12
 type wrapperStats struct {
13
 type wrapperStats struct {
13
-	parent StreamReadWriteCloser
14
+	parent conntypes.StreamReadWriteCloser
14
 }
15
 }
15
 
16
 
16
 func (w *wrapperStats) Write(p []byte) (int, error) {
17
 func (w *wrapperStats) Write(p []byte) (int, error) {
61
 	return w.parent.Close()
62
 	return w.parent.Close()
62
 }
63
 }
63
 
64
 
64
-func NewTraffic(parent StreamReadWriteCloser) StreamReadWriteCloser {
65
+func NewTraffic(parent conntypes.StreamReadWriteCloser) conntypes.StreamReadWriteCloser {
65
 	return &wrapperStats{parent}
66
 	return &wrapperStats{parent}
66
 }
67
 }

+ 61
- 0
wrappers/timeout.go ファイルの表示

1
+package wrappers
2
+
3
+import (
4
+	"net"
5
+	"time"
6
+
7
+	"go.uber.org/zap"
8
+
9
+	"github.com/9seconds/mtg/conntypes"
10
+)
11
+
12
+const (
13
+	timeoutRead  = 2 * time.Minute
14
+	timeoutWrite = 2 * time.Minute
15
+)
16
+
17
+type wrapperTimeout struct {
18
+	parent conntypes.StreamReadWriteCloser
19
+}
20
+
21
+func (w *wrapperTimeout) WriteTimeout(p []byte, timeout time.Duration) (int, error) {
22
+	return w.parent.WriteTimeout(p, timeout)
23
+}
24
+
25
+func (w *wrapperTimeout) Write(p []byte) (int, error) {
26
+	return w.parent.WriteTimeout(p, timeoutWrite)
27
+}
28
+
29
+func (w *wrapperTimeout) ReadTimeout(p []byte, timeout time.Duration) (int, error) {
30
+	return w.parent.ReadTimeout(p, timeout)
31
+}
32
+
33
+func (w *wrapperTimeout) Read(p []byte) (int, error) {
34
+	return w.parent.ReadTimeout(p, timeoutRead)
35
+}
36
+
37
+func (w *wrapperTimeout) Close() error {
38
+	return w.parent.Close()
39
+}
40
+
41
+func (w *wrapperTimeout) Conn() net.Conn {
42
+	return w.parent.Conn()
43
+}
44
+
45
+func (w *wrapperTimeout) Logger() *zap.SugaredLogger {
46
+	return w.parent.Logger().Named("timeout")
47
+}
48
+
49
+func (w *wrapperTimeout) LocalAddr() *net.TCPAddr {
50
+	return w.parent.LocalAddr()
51
+}
52
+
53
+func (w *wrapperTimeout) RemoteAddr() *net.TCPAddr {
54
+	return w.parent.RemoteAddr()
55
+}
56
+
57
+func NewTimeout(parent conntypes.StreamReadWriteCloser) conntypes.StreamReadWriteCloser {
58
+	return &wrapperTimeout{
59
+		parent: parent,
60
+	}
61
+}

読み込み中…
キャンセル
保存