Browse Source

Use contexts for Conn wrapper

tags/0.11^2
9seconds 7 years ago
parent
commit
9f20e8749a
9 changed files with 73 additions and 32 deletions
  1. 3
    1
      client/client.go
  2. 4
    2
      client/direct.go
  3. 4
    2
      client/middle.go
  4. 7
    4
      proxy/proxy.go
  5. 5
    2
      telegram/dialer.go
  6. 4
    2
      telegram/direct.go
  7. 3
    2
      telegram/middle_caller.go
  8. 4
    3
      telegram/telegram.go
  9. 39
    14
      wrappers/conn.go

+ 3
- 1
client/client.go View File

1
 package client
1
 package client
2
 
2
 
3
 import (
3
 import (
4
+	"context"
4
 	"net"
5
 	"net"
5
 
6
 
6
 	"github.com/9seconds/mtg/config"
7
 	"github.com/9seconds/mtg/config"
9
 )
10
 )
10
 
11
 
11
 // Init defines common method for initializing client connections.
12
 // Init defines common method for initializing client connections.
12
-type Init func(net.Conn, string, *config.Config) (wrappers.Wrap, *mtproto.ConnectionOpts, error)
13
+type Init func(context.Context, context.CancelFunc, net.Conn, string,
14
+	*config.Config) (wrappers.Wrap, *mtproto.ConnectionOpts, error)

+ 4
- 2
client/direct.go View File

1
 package client
1
 package client
2
 
2
 
3
 import (
3
 import (
4
+	"context"
4
 	"net"
5
 	"net"
5
 	"time"
6
 	"time"
6
 
7
 
16
 
17
 
17
 // DirectInit initializes client connection for proxy which connects to
18
 // DirectInit initializes client connection for proxy which connects to
18
 // Telegram directly.
19
 // Telegram directly.
19
-func DirectInit(socket net.Conn, connID string, conf *config.Config) (wrappers.Wrap, *mtproto.ConnectionOpts, error) {
20
+func DirectInit(ctx context.Context, cancel context.CancelFunc, socket net.Conn,
21
+	connID string, conf *config.Config) (wrappers.Wrap, *mtproto.ConnectionOpts, error) {
20
 	tcpSocket := socket.(*net.TCPConn)
22
 	tcpSocket := socket.(*net.TCPConn)
21
 	if err := tcpSocket.SetNoDelay(false); err != nil {
23
 	if err := tcpSocket.SetNoDelay(false); err != nil {
22
 		return nil, nil, errors.Annotate(err, "Cannot disable NO_DELAY to client socket")
24
 		return nil, nil, errors.Annotate(err, "Cannot disable NO_DELAY to client socket")
35
 	}
37
 	}
36
 	socket.SetReadDeadline(time.Time{}) // nolint: errcheck
38
 	socket.SetReadDeadline(time.Time{}) // nolint: errcheck
37
 
39
 
38
-	conn := wrappers.NewConn(socket, connID, wrappers.ConnPurposeClient, conf.PublicIPv4, conf.PublicIPv6)
40
+	conn := wrappers.NewConn(ctx, cancel, socket, connID, wrappers.ConnPurposeClient, conf.PublicIPv4, conf.PublicIPv6)
39
 	obfs2, connOpts, err := obfuscated2.ParseObfuscated2ClientFrame(conf.Secret, frame)
41
 	obfs2, connOpts, err := obfuscated2.ParseObfuscated2ClientFrame(conf.Secret, frame)
40
 	if err != nil {
42
 	if err != nil {
41
 		return nil, nil, errors.Annotate(err, "Cannot parse obfuscated frame")
43
 		return nil, nil, errors.Annotate(err, "Cannot parse obfuscated frame")

+ 4
- 2
client/middle.go View File

1
 package client
1
 package client
2
 
2
 
3
 import (
3
 import (
4
+	"context"
4
 	"net"
5
 	"net"
5
 
6
 
6
 	"github.com/9seconds/mtg/config"
7
 	"github.com/9seconds/mtg/config"
10
 
11
 
11
 // MiddleInit initializes client connection for proxy which has to
12
 // MiddleInit initializes client connection for proxy which has to
12
 // support promoted channels, connect to Telegram middle proxies etc.
13
 // support promoted channels, connect to Telegram middle proxies etc.
13
-func MiddleInit(socket net.Conn, connID string, conf *config.Config) (wrappers.Wrap, *mtproto.ConnectionOpts, error) {
14
-	conn, opts, err := DirectInit(socket, connID, conf)
14
+func MiddleInit(ctx context.Context, cancel context.CancelFunc, socket net.Conn,
15
+	connID string, conf *config.Config) (wrappers.Wrap, *mtproto.ConnectionOpts, error) {
16
+	conn, opts, err := DirectInit(ctx, cancel, socket, connID, conf)
15
 	if err != nil {
17
 	if err != nil {
16
 		return nil, nil, err
18
 		return nil, nil, err
17
 	}
19
 	}

+ 7
- 4
proxy/proxy.go View File

1
 package proxy
1
 package proxy
2
 
2
 
3
 import (
3
 import (
4
+	"context"
4
 	"io"
5
 	"io"
5
 	"net"
6
 	"net"
6
 	"sync"
7
 	"sync"
43
 func (p *Proxy) accept(conn net.Conn) {
44
 func (p *Proxy) accept(conn net.Conn) {
44
 	connID := uuid.NewV4().String()
45
 	connID := uuid.NewV4().String()
45
 	log := zap.S().With("connection_id", connID).Named("main")
46
 	log := zap.S().With("connection_id", connID).Named("main")
47
+	ctx, cancel := context.WithCancel(context.Background())
46
 
48
 
47
 	defer func() {
49
 	defer func() {
48
 		conn.Close() // nolint: errcheck
50
 		conn.Close() // nolint: errcheck
55
 
57
 
56
 	log.Infow("Client connected", "addr", conn.RemoteAddr())
58
 	log.Infow("Client connected", "addr", conn.RemoteAddr())
57
 
59
 
58
-	clientConn, opts, err := p.clientInit(conn, connID, p.conf)
60
+	clientConn, opts, err := p.clientInit(ctx, cancel, conn, connID, p.conf)
59
 	if err != nil {
61
 	if err != nil {
60
 		log.Errorw("Cannot initialize client connection", "error", err)
62
 		log.Errorw("Cannot initialize client connection", "error", err)
61
 		return
63
 		return
65
 	stats.ClientConnected(opts.ConnectionType, clientConn.RemoteAddr())
67
 	stats.ClientConnected(opts.ConnectionType, clientConn.RemoteAddr())
66
 	defer stats.ClientDisconnected(opts.ConnectionType, clientConn.RemoteAddr())
68
 	defer stats.ClientDisconnected(opts.ConnectionType, clientConn.RemoteAddr())
67
 
69
 
68
-	serverConn, err := p.getTelegramConn(opts, connID)
70
+	serverConn, err := p.getTelegramConn(ctx, cancel, opts, connID)
69
 	if err != nil {
71
 	if err != nil {
70
 		log.Errorw("Cannot initialize server connection", "error", err)
72
 		log.Errorw("Cannot initialize server connection", "error", err)
71
 		return
73
 		return
92
 	log.Infow("Client disconnected", "addr", conn.RemoteAddr())
94
 	log.Infow("Client disconnected", "addr", conn.RemoteAddr())
93
 }
95
 }
94
 
96
 
95
-func (p *Proxy) getTelegramConn(opts *mtproto.ConnectionOpts, connID string) (wrappers.Wrap, error) {
96
-	streamConn, err := p.tg.Dial(connID, opts)
97
+func (p *Proxy) getTelegramConn(ctx context.Context, cancel context.CancelFunc,
98
+	opts *mtproto.ConnectionOpts, connID string) (wrappers.Wrap, error) {
99
+	streamConn, err := p.tg.Dial(ctx, cancel, connID, opts)
97
 	if err != nil {
100
 	if err != nil {
98
 		return nil, errors.Annotate(err, "Cannot dial to Telegram")
101
 		return nil, errors.Annotate(err, "Cannot dial to Telegram")
99
 	}
102
 	}

+ 5
- 2
telegram/dialer.go View File

1
 package telegram
1
 package telegram
2
 
2
 
3
 import (
3
 import (
4
+	"context"
4
 	"net"
5
 	"net"
5
 	"time"
6
 	"time"
6
 
7
 
38
 	return conn, nil
39
 	return conn, nil
39
 }
40
 }
40
 
41
 
41
-func (t *tgDialer) dialRWC(addr, connID string) (wrappers.StreamReadWriteCloser, error) {
42
+func (t *tgDialer) dialRWC(ctx context.Context, cancel context.CancelFunc,
43
+	addr, connID string) (wrappers.StreamReadWriteCloser, error) {
42
 	conn, err := t.dial(addr)
44
 	conn, err := t.dial(addr)
43
 	if err != nil {
45
 	if err != nil {
44
 		return nil, err
46
 		return nil, err
45
 	}
47
 	}
46
-	tgConn := wrappers.NewConn(conn, connID, wrappers.ConnPurposeTelegram, t.conf.PublicIPv4, t.conf.PublicIPv6)
48
+	tgConn := wrappers.NewConn(ctx, cancel, conn, connID,
49
+		wrappers.ConnPurposeTelegram, t.conf.PublicIPv4, t.conf.PublicIPv6)
47
 
50
 
48
 	return tgConn, nil
51
 	return tgConn, nil
49
 }
52
 }

+ 4
- 2
telegram/direct.go View File

1
 package telegram
1
 package telegram
2
 
2
 
3
 import (
3
 import (
4
+	"context"
4
 	"net"
5
 	"net"
5
 
6
 
6
 	"github.com/juju/errors"
7
 	"github.com/juju/errors"
32
 	baseTelegram
33
 	baseTelegram
33
 }
34
 }
34
 
35
 
35
-func (t *directTelegram) Dial(connID string, connOpts *mtproto.ConnectionOpts) (wrappers.StreamReadWriteCloser, error) {
36
+func (t *directTelegram) Dial(ctx context.Context, cancel context.CancelFunc,
37
+	connID string, connOpts *mtproto.ConnectionOpts) (wrappers.StreamReadWriteCloser, error) {
36
 	dc := connOpts.DC
38
 	dc := connOpts.DC
37
 	if dc < 0 {
39
 	if dc < 0 {
38
 		dc = -dc
40
 		dc = -dc
40
 		dc = 1
42
 		dc = 1
41
 	}
43
 	}
42
 
44
 
43
-	return t.baseTelegram.dial(dc-1, connID, connOpts.ConnectionProto)
45
+	return t.baseTelegram.dial(ctx, cancel, dc-1, connID, connOpts.ConnectionProto)
44
 }
46
 }
45
 
47
 
46
 func (t *directTelegram) Init(connOpts *mtproto.ConnectionOpts,
48
 func (t *directTelegram) Init(connOpts *mtproto.ConnectionOpts,

+ 3
- 2
telegram/middle_caller.go View File

2
 
2
 
3
 import (
3
 import (
4
 	"bufio"
4
 	"bufio"
5
+	"context"
5
 	"io/ioutil"
6
 	"io/ioutil"
6
 	"net"
7
 	"net"
7
 	"net/http"
8
 	"net/http"
38
 	httpClient  *http.Client
39
 	httpClient  *http.Client
39
 }
40
 }
40
 
41
 
41
-func (t *middleTelegramCaller) Dial(connID string,
42
+func (t *middleTelegramCaller) Dial(ctx context.Context, cancel context.CancelFunc, connID string,
42
 	connOpts *mtproto.ConnectionOpts) (wrappers.StreamReadWriteCloser, error) {
43
 	connOpts *mtproto.ConnectionOpts) (wrappers.StreamReadWriteCloser, error) {
43
 	dc := connOpts.DC
44
 	dc := connOpts.DC
44
 	if dc == 0 {
45
 	if dc == 0 {
47
 	t.dialerMutex.RLock()
48
 	t.dialerMutex.RLock()
48
 	defer t.dialerMutex.RUnlock()
49
 	defer t.dialerMutex.RUnlock()
49
 
50
 
50
-	return t.baseTelegram.dial(dc, connID, connOpts.ConnectionProto)
51
+	return t.baseTelegram.dial(ctx, cancel, dc, connID, connOpts.ConnectionProto)
51
 }
52
 }
52
 
53
 
53
 func (t *middleTelegramCaller) autoUpdate() {
54
 func (t *middleTelegramCaller) autoUpdate() {

+ 4
- 3
telegram/telegram.go View File

1
 package telegram
1
 package telegram
2
 
2
 
3
 import (
3
 import (
4
+	"context"
4
 	"math/rand"
5
 	"math/rand"
5
 
6
 
6
 	"github.com/juju/errors"
7
 	"github.com/juju/errors"
11
 
12
 
12
 // Telegram is an interface for different Telegram work modes.
13
 // Telegram is an interface for different Telegram work modes.
13
 type Telegram interface {
14
 type Telegram interface {
14
-	Dial(string, *mtproto.ConnectionOpts) (wrappers.StreamReadWriteCloser, error)
15
+	Dial(context.Context, context.CancelFunc, string, *mtproto.ConnectionOpts) (wrappers.StreamReadWriteCloser, error)
15
 	Init(*mtproto.ConnectionOpts, wrappers.StreamReadWriteCloser) (wrappers.Wrap, error)
16
 	Init(*mtproto.ConnectionOpts, wrappers.StreamReadWriteCloser) (wrappers.Wrap, error)
16
 }
17
 }
17
 
18
 
22
 	v6Addresses map[int16][]string
23
 	v6Addresses map[int16][]string
23
 }
24
 }
24
 
25
 
25
-func (b *baseTelegram) dial(dcIdx int16, connID string,
26
+func (b *baseTelegram) dial(ctx context.Context, cancel context.CancelFunc, dcIdx int16, connID string,
26
 	proto mtproto.ConnectionProtocol) (wrappers.StreamReadWriteCloser, error) {
27
 	proto mtproto.ConnectionProtocol) (wrappers.StreamReadWriteCloser, error) {
27
 	addrs := make([]string, 2)
28
 	addrs := make([]string, 2)
28
 
29
 
38
 	}
39
 	}
39
 
40
 
40
 	for _, addr := range addrs {
41
 	for _, addr := range addrs {
41
-		if conn, err := b.dialer.dialRWC(addr, connID); err == nil {
42
+		if conn, err := b.dialer.dialRWC(ctx, cancel, addr, connID); err == nil {
42
 			return conn, err
43
 			return conn, err
43
 		}
44
 		}
44
 	}
45
 	}

+ 39
- 14
wrappers/conn.go View File

1
 package wrappers
1
 package wrappers
2
 
2
 
3
 import (
3
 import (
4
+	"context"
4
 	"net"
5
 	"net"
5
 	"time"
6
 	"time"
6
 
7
 
7
 	"go.uber.org/zap"
8
 	"go.uber.org/zap"
8
 
9
 
9
 	"github.com/9seconds/mtg/stats"
10
 	"github.com/9seconds/mtg/stats"
11
+	"github.com/juju/errors"
10
 )
12
 )
11
 
13
 
12
 // ConnPurpose is intended to be identifier of connection purpose. We
14
 // ConnPurpose is intended to be identifier of connection purpose. We
39
 // Conn is a basic wrapper for net.Conn providing the most low-level
41
 // Conn is a basic wrapper for net.Conn providing the most low-level
40
 // logic and management as possible.
42
 // logic and management as possible.
41
 type Conn struct {
43
 type Conn struct {
42
-	connID string
43
 	conn   net.Conn
44
 	conn   net.Conn
45
+	ctx    context.Context
46
+	cancel context.CancelFunc
47
+	connID string
44
 	logger *zap.SugaredLogger
48
 	logger *zap.SugaredLogger
45
 
49
 
46
 	publicIPv4 net.IP
50
 	publicIPv4 net.IP
48
 }
52
 }
49
 
53
 
50
 func (c *Conn) Write(p []byte) (int, error) {
54
 func (c *Conn) Write(p []byte) (int, error) {
51
-	c.conn.SetWriteDeadline(time.Now().Add(connTimeoutWrite)) // nolint: errcheck
52
-	n, err := c.conn.Write(p)
55
+	select {
56
+	case <-c.ctx.Done():
57
+		return 0, errors.Annotate(c.ctx.Err(), "Cannot write because context was closed")
58
+	default:
59
+		c.conn.SetWriteDeadline(time.Now().Add(connTimeoutWrite)) // nolint: errcheck
60
+		n, err := c.conn.Write(p)
61
+		if err != nil {
62
+			c.cancel()
63
+		}
53
 
64
 
54
-	c.logger.Debugw("Write to stream", "bytes", n, "error", err)
55
-	stats.EgressTraffic(n)
65
+		c.logger.Debugw("Write to stream", "bytes", n, "error", err)
66
+		stats.EgressTraffic(n)
56
 
67
 
57
-	return n, err
68
+		return n, err
69
+	}
58
 }
70
 }
59
 
71
 
60
 func (c *Conn) Read(p []byte) (int, error) {
72
 func (c *Conn) Read(p []byte) (int, error) {
61
-	c.conn.SetReadDeadline(time.Now().Add(connTimeoutRead)) // nolint: errcheck
62
-	n, err := c.conn.Read(p)
73
+	select {
74
+	case <-c.ctx.Done():
75
+		return 0, errors.Annotate(c.ctx.Err(), "Cannot read because context was closed")
76
+	default:
77
+		c.conn.SetReadDeadline(time.Now().Add(connTimeoutRead)) // nolint: errcheck
78
+		n, err := c.conn.Read(p)
79
+		if err != nil {
80
+			c.cancel()
81
+		}
63
 
82
 
64
-	c.logger.Debugw("Read from stream", "bytes", n, "error", err)
65
-	stats.IngressTraffic(n)
83
+		c.logger.Debugw("Read from stream", "bytes", n, "error", err)
84
+		stats.IngressTraffic(n)
66
 
85
 
67
-	return n, err
86
+		return n, err
87
+	}
68
 }
88
 }
69
 
89
 
70
 // Close closes underlying net.Conn instance.
90
 // Close closes underlying net.Conn instance.
71
 func (c *Conn) Close() error {
91
 func (c *Conn) Close() error {
72
 	defer c.logger.Debugw("Close connection")
92
 	defer c.logger.Debugw("Close connection")
93
+
94
+	c.cancel()
73
 	return c.conn.Close()
95
 	return c.conn.Close()
74
 }
96
 }
75
 
97
 
100
 }
122
 }
101
 
123
 
102
 // NewConn initializes Conn wrapper for net.Conn.
124
 // NewConn initializes Conn wrapper for net.Conn.
103
-func NewConn(conn net.Conn, connID string, purpose ConnPurpose, publicIPv4, publicIPv6 net.IP) StreamReadWriteCloser {
125
+func NewConn(ctx context.Context, cancel context.CancelFunc, conn net.Conn,
126
+	connID string, purpose ConnPurpose, publicIPv4, publicIPv6 net.IP) StreamReadWriteCloser {
104
 	logger := zap.S().With(
127
 	logger := zap.S().With(
105
 		"connection_id", connID,
128
 		"connection_id", connID,
106
 		"local_address", conn.LocalAddr(),
129
 		"local_address", conn.LocalAddr(),
109
 	).Named("conn")
132
 	).Named("conn")
110
 
133
 
111
 	wrapper := Conn{
134
 	wrapper := Conn{
112
-		logger:     logger,
113
-		connID:     connID,
114
 		conn:       conn,
135
 		conn:       conn,
136
+		ctx:        ctx,
137
+		cancel:     cancel,
138
+		connID:     connID,
139
+		logger:     logger,
115
 		publicIPv4: publicIPv4,
140
 		publicIPv4: publicIPv4,
116
 		publicIPv6: publicIPv6,
141
 		publicIPv6: publicIPv6,
117
 	}
142
 	}

Loading…
Cancel
Save