Просмотр исходного кода

Merge pull request #33 from 9seconds/contexts

Use contexts for Conn wrapper
tags/0.11^0
Sergey Arkhipov 7 лет назад
Родитель
Сommit
243a89a68c
Аккаунт пользователя с таким Email не найден
9 измененных файлов: 73 добавлений и 32 удалений
  1. 3
    1
      client/client.go
  2. 4
    2
      client/direct.go
  3. 4
    2
      client/middle.go
  4. 7
    4
      proxy/proxy.go
  5. 5
    2
      telegram/dialer.go
  6. 4
    2
      telegram/direct.go
  7. 3
    2
      telegram/middle_caller.go
  8. 4
    3
      telegram/telegram.go
  9. 39
    14
      wrappers/conn.go

+ 3
- 1
client/client.go Просмотреть файл

@@ -1,6 +1,7 @@
1 1
 package client
2 2
 
3 3
 import (
4
+	"context"
4 5
 	"net"
5 6
 
6 7
 	"github.com/9seconds/mtg/config"
@@ -9,4 +10,5 @@ import (
9 10
 )
10 11
 
11 12
 // Init defines common method for initializing client connections.
12
-type Init func(net.Conn, string, *config.Config) (wrappers.Wrap, *mtproto.ConnectionOpts, error)
13
+type Init func(context.Context, context.CancelFunc, net.Conn, string,
14
+	*config.Config) (wrappers.Wrap, *mtproto.ConnectionOpts, error)

+ 4
- 2
client/direct.go Просмотреть файл

@@ -1,6 +1,7 @@
1 1
 package client
2 2
 
3 3
 import (
4
+	"context"
4 5
 	"net"
5 6
 	"time"
6 7
 
@@ -16,7 +17,8 @@ const handshakeTimeout = 10 * time.Second
16 17
 
17 18
 // DirectInit initializes client connection for proxy which connects to
18 19
 // Telegram directly.
19
-func DirectInit(socket net.Conn, connID string, conf *config.Config) (wrappers.Wrap, *mtproto.ConnectionOpts, error) {
20
+func DirectInit(ctx context.Context, cancel context.CancelFunc, socket net.Conn,
21
+	connID string, conf *config.Config) (wrappers.Wrap, *mtproto.ConnectionOpts, error) {
20 22
 	tcpSocket := socket.(*net.TCPConn)
21 23
 	if err := tcpSocket.SetNoDelay(false); err != nil {
22 24
 		return nil, nil, errors.Annotate(err, "Cannot disable NO_DELAY to client socket")
@@ -35,7 +37,7 @@ func DirectInit(socket net.Conn, connID string, conf *config.Config) (wrappers.W
35 37
 	}
36 38
 	socket.SetReadDeadline(time.Time{}) // nolint: errcheck
37 39
 
38
-	conn := wrappers.NewConn(socket, connID, wrappers.ConnPurposeClient, conf.PublicIPv4, conf.PublicIPv6)
40
+	conn := wrappers.NewConn(ctx, cancel, socket, connID, wrappers.ConnPurposeClient, conf.PublicIPv4, conf.PublicIPv6)
39 41
 	obfs2, connOpts, err := obfuscated2.ParseObfuscated2ClientFrame(conf.Secret, frame)
40 42
 	if err != nil {
41 43
 		return nil, nil, errors.Annotate(err, "Cannot parse obfuscated frame")

+ 4
- 2
client/middle.go Просмотреть файл

@@ -1,6 +1,7 @@
1 1
 package client
2 2
 
3 3
 import (
4
+	"context"
4 5
 	"net"
5 6
 
6 7
 	"github.com/9seconds/mtg/config"
@@ -10,8 +11,9 @@ import (
10 11
 
11 12
 // MiddleInit initializes client connection for proxy which has to
12 13
 // support promoted channels, connect to Telegram middle proxies etc.
13
-func MiddleInit(socket net.Conn, connID string, conf *config.Config) (wrappers.Wrap, *mtproto.ConnectionOpts, error) {
14
-	conn, opts, err := DirectInit(socket, connID, conf)
14
+func MiddleInit(ctx context.Context, cancel context.CancelFunc, socket net.Conn,
15
+	connID string, conf *config.Config) (wrappers.Wrap, *mtproto.ConnectionOpts, error) {
16
+	conn, opts, err := DirectInit(ctx, cancel, socket, connID, conf)
15 17
 	if err != nil {
16 18
 		return nil, nil, err
17 19
 	}

+ 7
- 4
proxy/proxy.go Просмотреть файл

@@ -1,6 +1,7 @@
1 1
 package proxy
2 2
 
3 3
 import (
4
+	"context"
4 5
 	"io"
5 6
 	"net"
6 7
 	"sync"
@@ -43,6 +44,7 @@ func (p *Proxy) Serve() error {
43 44
 func (p *Proxy) accept(conn net.Conn) {
44 45
 	connID := uuid.NewV4().String()
45 46
 	log := zap.S().With("connection_id", connID).Named("main")
47
+	ctx, cancel := context.WithCancel(context.Background())
46 48
 
47 49
 	defer func() {
48 50
 		conn.Close() // nolint: errcheck
@@ -55,7 +57,7 @@ func (p *Proxy) accept(conn net.Conn) {
55 57
 
56 58
 	log.Infow("Client connected", "addr", conn.RemoteAddr())
57 59
 
58
-	clientConn, opts, err := p.clientInit(conn, connID, p.conf)
60
+	clientConn, opts, err := p.clientInit(ctx, cancel, conn, connID, p.conf)
59 61
 	if err != nil {
60 62
 		log.Errorw("Cannot initialize client connection", "error", err)
61 63
 		return
@@ -65,7 +67,7 @@ func (p *Proxy) accept(conn net.Conn) {
65 67
 	stats.ClientConnected(opts.ConnectionType, clientConn.RemoteAddr())
66 68
 	defer stats.ClientDisconnected(opts.ConnectionType, clientConn.RemoteAddr())
67 69
 
68
-	serverConn, err := p.getTelegramConn(opts, connID)
70
+	serverConn, err := p.getTelegramConn(ctx, cancel, opts, connID)
69 71
 	if err != nil {
70 72
 		log.Errorw("Cannot initialize server connection", "error", err)
71 73
 		return
@@ -92,8 +94,9 @@ func (p *Proxy) accept(conn net.Conn) {
92 94
 	log.Infow("Client disconnected", "addr", conn.RemoteAddr())
93 95
 }
94 96
 
95
-func (p *Proxy) getTelegramConn(opts *mtproto.ConnectionOpts, connID string) (wrappers.Wrap, error) {
96
-	streamConn, err := p.tg.Dial(connID, opts)
97
+func (p *Proxy) getTelegramConn(ctx context.Context, cancel context.CancelFunc,
98
+	opts *mtproto.ConnectionOpts, connID string) (wrappers.Wrap, error) {
99
+	streamConn, err := p.tg.Dial(ctx, cancel, connID, opts)
97 100
 	if err != nil {
98 101
 		return nil, errors.Annotate(err, "Cannot dial to Telegram")
99 102
 	}

+ 5
- 2
telegram/dialer.go Просмотреть файл

@@ -1,6 +1,7 @@
1 1
 package telegram
2 2
 
3 3
 import (
4
+	"context"
4 5
 	"net"
5 6
 	"time"
6 7
 
@@ -38,12 +39,14 @@ func (t *tgDialer) dial(addr string) (net.Conn, error) {
38 39
 	return conn, nil
39 40
 }
40 41
 
41
-func (t *tgDialer) dialRWC(addr, connID string) (wrappers.StreamReadWriteCloser, error) {
42
+func (t *tgDialer) dialRWC(ctx context.Context, cancel context.CancelFunc,
43
+	addr, connID string) (wrappers.StreamReadWriteCloser, error) {
42 44
 	conn, err := t.dial(addr)
43 45
 	if err != nil {
44 46
 		return nil, err
45 47
 	}
46
-	tgConn := wrappers.NewConn(conn, connID, wrappers.ConnPurposeTelegram, t.conf.PublicIPv4, t.conf.PublicIPv6)
48
+	tgConn := wrappers.NewConn(ctx, cancel, conn, connID,
49
+		wrappers.ConnPurposeTelegram, t.conf.PublicIPv4, t.conf.PublicIPv6)
47 50
 
48 51
 	return tgConn, nil
49 52
 }

+ 4
- 2
telegram/direct.go Просмотреть файл

@@ -1,6 +1,7 @@
1 1
 package telegram
2 2
 
3 3
 import (
4
+	"context"
4 5
 	"net"
5 6
 
6 7
 	"github.com/juju/errors"
@@ -32,7 +33,8 @@ type directTelegram struct {
32 33
 	baseTelegram
33 34
 }
34 35
 
35
-func (t *directTelegram) Dial(connID string, connOpts *mtproto.ConnectionOpts) (wrappers.StreamReadWriteCloser, error) {
36
+func (t *directTelegram) Dial(ctx context.Context, cancel context.CancelFunc,
37
+	connID string, connOpts *mtproto.ConnectionOpts) (wrappers.StreamReadWriteCloser, error) {
36 38
 	dc := connOpts.DC
37 39
 	if dc < 0 {
38 40
 		dc = -dc
@@ -40,7 +42,7 @@ func (t *directTelegram) Dial(connID string, connOpts *mtproto.ConnectionOpts) (
40 42
 		dc = 1
41 43
 	}
42 44
 
43
-	return t.baseTelegram.dial(dc-1, connID, connOpts.ConnectionProto)
45
+	return t.baseTelegram.dial(ctx, cancel, dc-1, connID, connOpts.ConnectionProto)
44 46
 }
45 47
 
46 48
 func (t *directTelegram) Init(connOpts *mtproto.ConnectionOpts,

+ 3
- 2
telegram/middle_caller.go Просмотреть файл

@@ -2,6 +2,7 @@ package telegram
2 2
 
3 3
 import (
4 4
 	"bufio"
5
+	"context"
5 6
 	"io/ioutil"
6 7
 	"net"
7 8
 	"net/http"
@@ -38,7 +39,7 @@ type middleTelegramCaller struct {
38 39
 	httpClient  *http.Client
39 40
 }
40 41
 
41
-func (t *middleTelegramCaller) Dial(connID string,
42
+func (t *middleTelegramCaller) Dial(ctx context.Context, cancel context.CancelFunc, connID string,
42 43
 	connOpts *mtproto.ConnectionOpts) (wrappers.StreamReadWriteCloser, error) {
43 44
 	dc := connOpts.DC
44 45
 	if dc == 0 {
@@ -47,7 +48,7 @@ func (t *middleTelegramCaller) Dial(connID string,
47 48
 	t.dialerMutex.RLock()
48 49
 	defer t.dialerMutex.RUnlock()
49 50
 
50
-	return t.baseTelegram.dial(dc, connID, connOpts.ConnectionProto)
51
+	return t.baseTelegram.dial(ctx, cancel, dc, connID, connOpts.ConnectionProto)
51 52
 }
52 53
 
53 54
 func (t *middleTelegramCaller) autoUpdate() {

+ 4
- 3
telegram/telegram.go Просмотреть файл

@@ -1,6 +1,7 @@
1 1
 package telegram
2 2
 
3 3
 import (
4
+	"context"
4 5
 	"math/rand"
5 6
 
6 7
 	"github.com/juju/errors"
@@ -11,7 +12,7 @@ import (
11 12
 
12 13
 // Telegram is an interface for different Telegram work modes.
13 14
 type Telegram interface {
14
-	Dial(string, *mtproto.ConnectionOpts) (wrappers.StreamReadWriteCloser, error)
15
+	Dial(context.Context, context.CancelFunc, string, *mtproto.ConnectionOpts) (wrappers.StreamReadWriteCloser, error)
15 16
 	Init(*mtproto.ConnectionOpts, wrappers.StreamReadWriteCloser) (wrappers.Wrap, error)
16 17
 }
17 18
 
@@ -22,7 +23,7 @@ type baseTelegram struct {
22 23
 	v6Addresses map[int16][]string
23 24
 }
24 25
 
25
-func (b *baseTelegram) dial(dcIdx int16, connID string,
26
+func (b *baseTelegram) dial(ctx context.Context, cancel context.CancelFunc, dcIdx int16, connID string,
26 27
 	proto mtproto.ConnectionProtocol) (wrappers.StreamReadWriteCloser, error) {
27 28
 	addrs := make([]string, 2)
28 29
 
@@ -38,7 +39,7 @@ func (b *baseTelegram) dial(dcIdx int16, connID string,
38 39
 	}
39 40
 
40 41
 	for _, addr := range addrs {
41
-		if conn, err := b.dialer.dialRWC(addr, connID); err == nil {
42
+		if conn, err := b.dialer.dialRWC(ctx, cancel, addr, connID); err == nil {
42 43
 			return conn, err
43 44
 		}
44 45
 	}

+ 39
- 14
wrappers/conn.go Просмотреть файл

@@ -1,12 +1,14 @@
1 1
 package wrappers
2 2
 
3 3
 import (
4
+	"context"
4 5
 	"net"
5 6
 	"time"
6 7
 
7 8
 	"go.uber.org/zap"
8 9
 
9 10
 	"github.com/9seconds/mtg/stats"
11
+	"github.com/juju/errors"
10 12
 )
11 13
 
12 14
 // ConnPurpose is intended to be identifier of connection purpose. We
@@ -39,8 +41,10 @@ const (
39 41
 // Conn is a basic wrapper for net.Conn providing the most low-level
40 42
 // logic and management as possible.
41 43
 type Conn struct {
42
-	connID string
43 44
 	conn   net.Conn
45
+	ctx    context.Context
46
+	cancel context.CancelFunc
47
+	connID string
44 48
 	logger *zap.SugaredLogger
45 49
 
46 50
 	publicIPv4 net.IP
@@ -48,28 +52,46 @@ type Conn struct {
48 52
 }
49 53
 
50 54
 func (c *Conn) Write(p []byte) (int, error) {
51
-	c.conn.SetWriteDeadline(time.Now().Add(connTimeoutWrite)) // nolint: errcheck
52
-	n, err := c.conn.Write(p)
55
+	select {
56
+	case <-c.ctx.Done():
57
+		return 0, errors.Annotate(c.ctx.Err(), "Cannot write because context was closed")
58
+	default:
59
+		c.conn.SetWriteDeadline(time.Now().Add(connTimeoutWrite)) // nolint: errcheck
60
+		n, err := c.conn.Write(p)
61
+		if err != nil {
62
+			c.cancel()
63
+		}
53 64
 
54
-	c.logger.Debugw("Write to stream", "bytes", n, "error", err)
55
-	stats.EgressTraffic(n)
65
+		c.logger.Debugw("Write to stream", "bytes", n, "error", err)
66
+		stats.EgressTraffic(n)
56 67
 
57
-	return n, err
68
+		return n, err
69
+	}
58 70
 }
59 71
 
60 72
 func (c *Conn) Read(p []byte) (int, error) {
61
-	c.conn.SetReadDeadline(time.Now().Add(connTimeoutRead)) // nolint: errcheck
62
-	n, err := c.conn.Read(p)
73
+	select {
74
+	case <-c.ctx.Done():
75
+		return 0, errors.Annotate(c.ctx.Err(), "Cannot read because context was closed")
76
+	default:
77
+		c.conn.SetReadDeadline(time.Now().Add(connTimeoutRead)) // nolint: errcheck
78
+		n, err := c.conn.Read(p)
79
+		if err != nil {
80
+			c.cancel()
81
+		}
63 82
 
64
-	c.logger.Debugw("Read from stream", "bytes", n, "error", err)
65
-	stats.IngressTraffic(n)
83
+		c.logger.Debugw("Read from stream", "bytes", n, "error", err)
84
+		stats.IngressTraffic(n)
66 85
 
67
-	return n, err
86
+		return n, err
87
+	}
68 88
 }
69 89
 
70 90
 // Close closes underlying net.Conn instance.
71 91
 func (c *Conn) Close() error {
72 92
 	defer c.logger.Debugw("Close connection")
93
+
94
+	c.cancel()
73 95
 	return c.conn.Close()
74 96
 }
75 97
 
@@ -100,7 +122,8 @@ func (c *Conn) RemoteAddr() *net.TCPAddr {
100 122
 }
101 123
 
102 124
 // NewConn initializes Conn wrapper for net.Conn.
103
-func NewConn(conn net.Conn, connID string, purpose ConnPurpose, publicIPv4, publicIPv6 net.IP) StreamReadWriteCloser {
125
+func NewConn(ctx context.Context, cancel context.CancelFunc, conn net.Conn,
126
+	connID string, purpose ConnPurpose, publicIPv4, publicIPv6 net.IP) StreamReadWriteCloser {
104 127
 	logger := zap.S().With(
105 128
 		"connection_id", connID,
106 129
 		"local_address", conn.LocalAddr(),
@@ -109,9 +132,11 @@ func NewConn(conn net.Conn, connID string, purpose ConnPurpose, publicIPv4, publ
109 132
 	).Named("conn")
110 133
 
111 134
 	wrapper := Conn{
112
-		logger:     logger,
113
-		connID:     connID,
114 135
 		conn:       conn,
136
+		ctx:        ctx,
137
+		cancel:     cancel,
138
+		connID:     connID,
139
+		logger:     logger,
115 140
 		publicIPv4: publicIPv4,
116 141
 		publicIPv6: publicIPv6,
117 142
 	}

Загрузка…
Отмена
Сохранить