Просмотр исходного кода

Use CloseRead and CloseWrites

tags/v2.1.3^2
9seconds 4 лет назад
Родитель
Сommit
ffad717829

+ 17
- 0
essentials/conns.go Просмотреть файл

1
+package essentials
2
+
3
+import "net"
4
+
5
+type CloseableReader interface {
6
+	CloseRead() error
7
+}
8
+
9
+type CloseableWriter interface {
10
+	CloseWrite() error
11
+}
12
+
13
+type Conn interface {
14
+	net.Conn
15
+	CloseableReader
16
+	CloseableWriter
17
+}

+ 4
- 0
go.mod Просмотреть файл

38
 	github.com/gotd/ige v0.1.5 // indirect
38
 	github.com/gotd/ige v0.1.5 // indirect
39
 	github.com/gotd/xor v0.1.1 // indirect
39
 	github.com/gotd/xor v0.1.1 // indirect
40
 	github.com/matttproud/golang_protobuf_extensions v1.0.1 // indirect
40
 	github.com/matttproud/golang_protobuf_extensions v1.0.1 // indirect
41
+	github.com/patrickmn/go-cache v2.1.0+incompatible // indirect
41
 	github.com/pkg/errors v0.9.1 // indirect
42
 	github.com/pkg/errors v0.9.1 // indirect
42
 	github.com/pmezard/go-difflib v1.0.0 // indirect
43
 	github.com/pmezard/go-difflib v1.0.0 // indirect
43
 	github.com/prometheus/client_model v0.2.0 // indirect
44
 	github.com/prometheus/client_model v0.2.0 // indirect
45
+	github.com/txthinking/runnergroup v0.0.0-20210608031112-152c7c4432bf // indirect
46
+	github.com/txthinking/socks5 v0.0.0-20211121111206-e03c1217a50b // indirect
47
+	github.com/txthinking/x v0.0.0-20210326105829-476fab902fbe // indirect
44
 	go.uber.org/atomic v1.7.0 // indirect
48
 	go.uber.org/atomic v1.7.0 // indirect
45
 	go.uber.org/multierr v1.6.0 // indirect
49
 	go.uber.org/multierr v1.6.0 // indirect
46
 	go.uber.org/zap v1.16.0 // indirect
50
 	go.uber.org/zap v1.16.0 // indirect

+ 8
- 0
go.sum Просмотреть файл

195
 github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U=
195
 github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U=
196
 github.com/panjf2000/ants/v2 v2.4.6 h1:drmj9mcygn2gawZ155dRbo+NfXEfAssjZNU1qoIb4gQ=
196
 github.com/panjf2000/ants/v2 v2.4.6 h1:drmj9mcygn2gawZ155dRbo+NfXEfAssjZNU1qoIb4gQ=
197
 github.com/panjf2000/ants/v2 v2.4.6/go.mod h1:f6F0NZVFsGCp5A7QW/Zj/m92atWwOkY0OIhFxRNFr4A=
197
 github.com/panjf2000/ants/v2 v2.4.6/go.mod h1:f6F0NZVFsGCp5A7QW/Zj/m92atWwOkY0OIhFxRNFr4A=
198
+github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc=
199
+github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ=
198
 github.com/pelletier/go-toml v1.9.4 h1:tjENF6MfZAg8e4ZmZTeWaWiT2vXtsoO6+iuOjFhECwM=
200
 github.com/pelletier/go-toml v1.9.4 h1:tjENF6MfZAg8e4ZmZTeWaWiT2vXtsoO6+iuOjFhECwM=
199
 github.com/pelletier/go-toml v1.9.4/go.mod h1:u1nR/EPcESfeI/szUZKdtJ0xRNbUoANCkoOuaOx1Y+c=
201
 github.com/pelletier/go-toml v1.9.4/go.mod h1:u1nR/EPcESfeI/szUZKdtJ0xRNbUoANCkoOuaOx1Y+c=
200
 github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
202
 github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
250
 github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
252
 github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
251
 github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY=
253
 github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY=
252
 github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
254
 github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
255
+github.com/txthinking/runnergroup v0.0.0-20210608031112-152c7c4432bf h1:7PflaKRtU4np/epFxRXlFhlzLXZzKFrH5/I4so5Ove0=
256
+github.com/txthinking/runnergroup v0.0.0-20210608031112-152c7c4432bf/go.mod h1:CLUSJbazqETbaR+i0YAhXBICV9TrKH93pziccMhmhpM=
257
+github.com/txthinking/socks5 v0.0.0-20211121111206-e03c1217a50b h1:6J/38A0Xmdnjacfie0Udams7OP/GdoExyTipKwuQWjY=
258
+github.com/txthinking/socks5 v0.0.0-20211121111206-e03c1217a50b/go.mod h1:7NloQcrxaZYKURWph5HLxVDlIwMHJXCPkeWPtpftsIg=
259
+github.com/txthinking/x v0.0.0-20210326105829-476fab902fbe h1:gMWxZxBFRAXqoGkwkYlPX2zvyyKNWJpxOxCrjqJkm5A=
260
+github.com/txthinking/x v0.0.0-20210326105829-476fab902fbe/go.mod h1:WgqbSEmUYSjEV3B1qmee/PpP2NYEz4bL9/+mF1ma+s4=
253
 github.com/tylertreat/BoomFilters v0.0.0-20210315201527-1a82519a3e43 h1:QEePdg0ty2r0t1+qwfZmQ4OOl/MB2UXIeJSpIZv56lg=
261
 github.com/tylertreat/BoomFilters v0.0.0-20210315201527-1a82519a3e43 h1:QEePdg0ty2r0t1+qwfZmQ4OOl/MB2UXIeJSpIZv56lg=
254
 github.com/tylertreat/BoomFilters v0.0.0-20210315201527-1a82519a3e43/go.mod h1:OYRfF6eb5wY9VRFkXJH8FFBi3plw2v+giaIu7P054pM=
262
 github.com/tylertreat/BoomFilters v0.0.0-20210315201527-1a82519a3e43/go.mod h1:OYRfF6eb5wY9VRFkXJH8FFBi3plw2v+giaIu7P054pM=
255
 github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
263
 github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=

+ 2
- 1
internal/cli/access.go Просмотреть файл

13
 	"strings"
13
 	"strings"
14
 	"sync"
14
 	"sync"
15
 
15
 
16
+	"github.com/9seconds/mtg/v2/essentials"
16
 	"github.com/9seconds/mtg/v2/internal/config"
17
 	"github.com/9seconds/mtg/v2/internal/config"
17
 	"github.com/9seconds/mtg/v2/internal/utils"
18
 	"github.com/9seconds/mtg/v2/internal/utils"
18
 	"github.com/9seconds/mtg/v2/mtglib"
19
 	"github.com/9seconds/mtg/v2/mtglib"
106
 }
107
 }
107
 
108
 
108
 func (a *Access) getIP(ntw mtglib.Network, protocol string) net.IP {
109
 func (a *Access) getIP(ntw mtglib.Network, protocol string) net.IP {
109
-	client := ntw.MakeHTTPClient(func(ctx context.Context, network, address string) (net.Conn, error) {
110
+	client := ntw.MakeHTTPClient(func(ctx context.Context, network, address string) (essentials.Conn, error) {
110
 		return ntw.DialContext(ctx, protocol, address) // nolint: wrapcheck
111
 		return ntw.DialContext(ctx, protocol, address) // nolint: wrapcheck
111
 	})
112
 	})
112
 
113
 

+ 6
- 6
internal/testlib/mtglib_network_mock.go Просмотреть файл

2
 
2
 
3
 import (
3
 import (
4
 	"context"
4
 	"context"
5
-	"net"
6
 	"net/http"
5
 	"net/http"
7
 
6
 
7
+	"github.com/9seconds/mtg/v2/essentials"
8
 	"github.com/stretchr/testify/mock"
8
 	"github.com/stretchr/testify/mock"
9
 )
9
 )
10
 
10
 
12
 	mock.Mock
12
 	mock.Mock
13
 }
13
 }
14
 
14
 
15
-func (m *MtglibNetworkMock) Dial(network, address string) (net.Conn, error) {
15
+func (m *MtglibNetworkMock) Dial(network, address string) (essentials.Conn, error) {
16
 	args := m.Called(network, address)
16
 	args := m.Called(network, address)
17
 
17
 
18
-	return args.Get(0).(net.Conn), args.Error(1) // nolint: wrapcheck
18
+	return args.Get(0).(essentials.Conn), args.Error(1) // nolint: wrapcheck
19
 }
19
 }
20
 
20
 
21
-func (m *MtglibNetworkMock) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
21
+func (m *MtglibNetworkMock) DialContext(ctx context.Context, network, address string) (essentials.Conn, error) {
22
 	args := m.Called(ctx, network, address)
22
 	args := m.Called(ctx, network, address)
23
 
23
 
24
-	return args.Get(0).(net.Conn), args.Error(1) // nolint: wrapcheck
24
+	return args.Get(0).(essentials.Conn), args.Error(1) // nolint: wrapcheck
25
 }
25
 }
26
 
26
 
27
 func (m *MtglibNetworkMock) MakeHTTPClient(dialFunc func(ctx context.Context,
27
 func (m *MtglibNetworkMock) MakeHTTPClient(dialFunc func(ctx context.Context,
28
-	network, address string) (net.Conn, error)) *http.Client {
28
+	network, address string) (essentials.Conn, error)) *http.Client {
29
 	return m.Called(dialFunc).Get(0).(*http.Client)
29
 	return m.Called(dialFunc).Get(0).(*http.Client)
30
 }
30
 }

+ 17
- 9
internal/testlib/net_conn_mock.go Просмотреть файл

7
 	"github.com/stretchr/testify/mock"
7
 	"github.com/stretchr/testify/mock"
8
 )
8
 )
9
 
9
 
10
-type NetConnMock struct {
10
+type EssentialsConnMock struct {
11
 	mock.Mock
11
 	mock.Mock
12
 }
12
 }
13
 
13
 
14
-func (n *NetConnMock) Read(b []byte) (int, error) {
14
+func (n *EssentialsConnMock) Read(b []byte) (int, error) {
15
 	args := n.Called(b)
15
 	args := n.Called(b)
16
 
16
 
17
 	return args.Int(0), args.Error(1)
17
 	return args.Int(0), args.Error(1)
18
 }
18
 }
19
 
19
 
20
-func (n *NetConnMock) Write(b []byte) (int, error) {
20
+func (n *EssentialsConnMock) Write(b []byte) (int, error) {
21
 	args := n.Called(b)
21
 	args := n.Called(b)
22
 
22
 
23
 	return args.Int(0), args.Error(1)
23
 	return args.Int(0), args.Error(1)
24
 }
24
 }
25
 
25
 
26
-func (n *NetConnMock) Close() error {
26
+func (n *EssentialsConnMock) Close() error {
27
 	return n.Called().Error(0) // nolint: wrapcheck
27
 	return n.Called().Error(0) // nolint: wrapcheck
28
 }
28
 }
29
 
29
 
30
-func (n *NetConnMock) LocalAddr() net.Addr {
30
+func (n *EssentialsConnMock) CloseRead() error {
31
+	return n.Called().Error(0) // nolint: wrapcheck
32
+}
33
+
34
+func (n *EssentialsConnMock) CloseWrite() error {
35
+	return n.Called().Error(0) // nolint: wrapcheck
36
+}
37
+
38
+func (n *EssentialsConnMock) LocalAddr() net.Addr {
31
 	return n.Called().Get(0).(net.Addr)
39
 	return n.Called().Get(0).(net.Addr)
32
 }
40
 }
33
 
41
 
34
-func (n *NetConnMock) RemoteAddr() net.Addr {
42
+func (n *EssentialsConnMock) RemoteAddr() net.Addr {
35
 	return n.Called().Get(0).(net.Addr)
43
 	return n.Called().Get(0).(net.Addr)
36
 }
44
 }
37
 
45
 
38
-func (n *NetConnMock) SetDeadline(t time.Time) error {
46
+func (n *EssentialsConnMock) SetDeadline(t time.Time) error {
39
 	return n.Called(t).Error(0) // nolint: wrapcheck
47
 	return n.Called(t).Error(0) // nolint: wrapcheck
40
 }
48
 }
41
 
49
 
42
-func (n *NetConnMock) SetReadDeadline(t time.Time) error {
50
+func (n *EssentialsConnMock) SetReadDeadline(t time.Time) error {
43
 	return n.Called(t).Error(0) // nolint: wrapcheck
51
 	return n.Called(t).Error(0) // nolint: wrapcheck
44
 }
52
 }
45
 
53
 
46
-func (n *NetConnMock) SetWriteDeadline(t time.Time) error {
54
+func (n *EssentialsConnMock) SetWriteDeadline(t time.Time) error {
47
 	return n.Called(t).Error(0) // nolint: wrapcheck
55
 	return n.Called(t).Error(0) // nolint: wrapcheck
48
 }
56
 }

+ 5
- 4
mtglib/conns.go Просмотреть файл

4
 	"bytes"
4
 	"bytes"
5
 	"context"
5
 	"context"
6
 	"io"
6
 	"io"
7
-	"net"
8
 	"sync"
7
 	"sync"
8
+
9
+	"github.com/9seconds/mtg/v2/essentials"
9
 )
10
 )
10
 
11
 
11
 type connTraffic struct {
12
 type connTraffic struct {
12
-	net.Conn
13
+	essentials.Conn
13
 
14
 
14
 	streamID string
15
 	streamID string
15
 	stream   EventStream
16
 	stream   EventStream
37
 }
38
 }
38
 
39
 
39
 type connRewind struct {
40
 type connRewind struct {
40
-	net.Conn
41
+	essentials.Conn
41
 
42
 
42
 	active io.Reader
43
 	active io.Reader
43
 	buf    bytes.Buffer
44
 	buf    bytes.Buffer
58
 	c.active = io.MultiReader(&c.buf, c.Conn)
59
 	c.active = io.MultiReader(&c.buf, c.Conn)
59
 }
60
 }
60
 
61
 
61
-func newConnRewind(conn net.Conn) *connRewind {
62
+func newConnRewind(conn essentials.Conn) *connRewind {
62
 	rv := &connRewind{
63
 	rv := &connRewind{
63
 		Conn: conn,
64
 		Conn: conn,
64
 	}
65
 	}

+ 3
- 3
mtglib/conns_internal_test.go Просмотреть файл

14
 )
14
 )
15
 
15
 
16
 type ConnRewindBaseConn struct {
16
 type ConnRewindBaseConn struct {
17
-	testlib.NetConnMock
17
+	testlib.EssentialsConnMock
18
 
18
 
19
 	readBuffer bytes.Buffer
19
 	readBuffer bytes.Buffer
20
 }
20
 }
29
 	suite.Suite
29
 	suite.Suite
30
 
30
 
31
 	eventStreamMock *EventStreamMock
31
 	eventStreamMock *EventStreamMock
32
-	connMock        *testlib.NetConnMock
32
+	connMock        *testlib.EssentialsConnMock
33
 	conn            io.ReadWriter
33
 	conn            io.ReadWriter
34
 }
34
 }
35
 
35
 
36
 func (suite *ConnTrafficTestSuite) SetupTest() {
36
 func (suite *ConnTrafficTestSuite) SetupTest() {
37
 	suite.eventStreamMock = &EventStreamMock{}
37
 	suite.eventStreamMock = &EventStreamMock{}
38
-	suite.connMock = &testlib.NetConnMock{}
38
+	suite.connMock = &testlib.EssentialsConnMock{}
39
 	suite.conn = connTraffic{
39
 	suite.conn = connTraffic{
40
 		Conn:     suite.connMock,
40
 		Conn:     suite.connMock,
41
 		streamID: "CONNID",
41
 		streamID: "CONNID",

+ 5
- 3
mtglib/init.go Просмотреть файл

23
 	"net"
23
 	"net"
24
 	"net/http"
24
 	"net/http"
25
 	"time"
25
 	"time"
26
+
27
+	"github.com/9seconds/mtg/v2/essentials"
26
 )
28
 )
27
 
29
 
28
 var (
30
 var (
116
 // 3. Doing HTTP requests (for example, for FireHOL ipblocklist).
118
 // 3. Doing HTTP requests (for example, for FireHOL ipblocklist).
117
 type Network interface {
119
 type Network interface {
118
 	// Dial establishes context-free TCP connections.
120
 	// Dial establishes context-free TCP connections.
119
-	Dial(network, address string) (net.Conn, error)
121
+	Dial(network, address string) (essentials.Conn, error)
120
 
122
 
121
 	// DialContext dials using a context. This is a preferrable
123
 	// DialContext dials using a context. This is a preferrable
122
 	// way of establishing TCP connections.
124
 	// way of establishing TCP connections.
123
-	DialContext(ctx context.Context, network, address string) (net.Conn, error)
125
+	DialContext(ctx context.Context, network, address string) (essentials.Conn, error)
124
 
126
 
125
 	// MakeHTTPClient build an HTTP client with given dial function. If
127
 	// MakeHTTPClient build an HTTP client with given dial function. If
126
 	// nothing is provided, then DialContext of this interface is going
128
 	// nothing is provided, then DialContext of this interface is going
127
 	// to be used.
129
 	// to be used.
128
-	MakeHTTPClient(func(ctx context.Context, network, address string) (net.Conn, error)) *http.Client
130
+	MakeHTTPClient(func(ctx context.Context, network, address string) (essentials.Conn, error)) *http.Client
129
 }
131
 }
130
 
132
 
131
 // AntiReplayCache is an interface that is used to detect replay attacks
133
 // AntiReplayCache is an interface that is used to detect replay attacks

+ 2
- 2
mtglib/internal/faketls/conn.go Просмотреть файл

4
 	"bytes"
4
 	"bytes"
5
 	"fmt"
5
 	"fmt"
6
 	"math/rand"
6
 	"math/rand"
7
-	"net"
8
 
7
 
8
+	"github.com/9seconds/mtg/v2/essentials"
9
 	"github.com/9seconds/mtg/v2/mtglib/internal/faketls/record"
9
 	"github.com/9seconds/mtg/v2/mtglib/internal/faketls/record"
10
 )
10
 )
11
 
11
 
12
 type Conn struct {
12
 type Conn struct {
13
-	net.Conn
13
+	essentials.Conn
14
 
14
 
15
 	readBuffer bytes.Buffer
15
 	readBuffer bytes.Buffer
16
 }
16
 }

+ 1
- 1
mtglib/internal/faketls/conn_test.go Просмотреть файл

15
 )
15
 )
16
 
16
 
17
 type ConnMock struct {
17
 type ConnMock struct {
18
-	testlib.NetConnMock
18
+	testlib.EssentialsConnMock
19
 
19
 
20
 	readBuffer  bytes.Buffer
20
 	readBuffer  bytes.Buffer
21
 	writeBuffer bytes.Buffer
21
 	writeBuffer bytes.Buffer

+ 1
- 1
mtglib/internal/obfuscated2/client_handshake_test.go Просмотреть файл

42
 			writeData := make([]byte, len(snapshot.Encrypted.Text.data))
42
 			writeData := make([]byte, len(snapshot.Encrypted.Text.data))
43
 			readData := make([]byte, len(snapshot.Decrypted.Text.data))
43
 			readData := make([]byte, len(snapshot.Decrypted.Text.data))
44
 
44
 
45
-			connMock := &testlib.NetConnMock{}
45
+			connMock := &testlib.EssentialsConnMock{}
46
 			connMock.On("Read", mock.Anything).
46
 			connMock.On("Read", mock.Anything).
47
 				Once().
47
 				Once().
48
 				Return(len(snapshot.Decrypted.Text.data), nil).
48
 				Return(len(snapshot.Decrypted.Text.data), nil).

+ 3
- 2
mtglib/internal/obfuscated2/conn.go Просмотреть файл

2
 
2
 
3
 import (
3
 import (
4
 	"crypto/cipher"
4
 	"crypto/cipher"
5
-	"net"
5
+
6
+	"github.com/9seconds/mtg/v2/essentials"
6
 )
7
 )
7
 
8
 
8
 type Conn struct {
9
 type Conn struct {
9
-	net.Conn
10
+	essentials.Conn
10
 
11
 
11
 	Encryptor cipher.Stream
12
 	Encryptor cipher.Stream
12
 	Decryptor cipher.Stream
13
 	Decryptor cipher.Stream

+ 2
- 2
mtglib/internal/obfuscated2/server_handshake_test.go Просмотреть файл

16
 type ServerHandshakeTestSuite struct {
16
 type ServerHandshakeTestSuite struct {
17
 	suite.Suite
17
 	suite.Suite
18
 
18
 
19
-	connMock  *testlib.NetConnMock
19
+	connMock  *testlib.EssentialsConnMock
20
 	proxyConn obfuscated2.Conn
20
 	proxyConn obfuscated2.Conn
21
 	encryptor cipher.Stream
21
 	encryptor cipher.Stream
22
 	decryptor cipher.Stream
22
 	decryptor cipher.Stream
24
 
24
 
25
 func (suite *ServerHandshakeTestSuite) SetupTest() {
25
 func (suite *ServerHandshakeTestSuite) SetupTest() {
26
 	buf := &bytes.Buffer{}
26
 	buf := &bytes.Buffer{}
27
-	suite.connMock = &testlib.NetConnMock{}
27
+	suite.connMock = &testlib.EssentialsConnMock{}
28
 
28
 
29
 	encryptor, decryptor, err := obfuscated2.ServerHandshake(buf)
29
 	encryptor, decryptor, err := obfuscated2.ServerHandshake(buf)
30
 	suite.NoError(err)
30
 	suite.NoError(err)

+ 23
- 9
mtglib/internal/relay/relay.go Просмотреть файл

2
 
2
 
3
 import (
3
 import (
4
 	"context"
4
 	"context"
5
-	"net"
5
+	"errors"
6
+	"io"
6
 	"sync"
7
 	"sync"
8
+
9
+	"github.com/9seconds/mtg/v2/essentials"
7
 )
10
 )
8
 
11
 
9
-func Relay(ctx context.Context, log Logger, telegramConn, clientConn net.Conn) {
12
+func Relay(ctx context.Context, log Logger, telegramConn, clientConn essentials.Conn) {
10
 	defer telegramConn.Close()
13
 	defer telegramConn.Close()
11
 	defer clientConn.Close()
14
 	defer clientConn.Close()
12
 
15
 
29
 	wg.Wait()
32
 	wg.Wait()
30
 }
33
 }
31
 
34
 
32
-func pump(log Logger, src, dst net.Conn, wg *sync.WaitGroup, direction string) {
33
-	defer wg.Done()
34
-
35
+func pump(log Logger, src, dst essentials.Conn, wg *sync.WaitGroup, direction string) {
35
 	syncer := acquireSyncPair(src, dst)
36
 	syncer := acquireSyncPair(src, dst)
36
-	defer releaseSyncPair(syncer)
37
-	defer syncer.Flush()
38
 
37
 
39
-	if n, err := syncer.Sync(); err != nil {
40
-		log.Printf("cannot pump %s (written %d bytes): %v", direction, n, err)
38
+	defer func() {
39
+		syncer.Flush()
40
+		releaseSyncPair(syncer)
41
+		src.CloseRead()
42
+		dst.CloseWrite()
43
+		wg.Done()
44
+	}()
45
+
46
+	n, err := syncer.Sync()
47
+
48
+	switch {
49
+	case err == nil:
50
+		log.Printf("%s has been finished", direction)
51
+	case errors.Is(err, io.EOF):
52
+		log.Printf("%s has been finished because of EOF. Written %d bytes", direction, n)
53
+	default:
54
+		log.Printf("%s has been finished (written %d bytes): %v", direction, n, err)
41
 	}
55
 	}
42
 }
56
 }

+ 10
- 4
mtglib/internal/relay/relay_test.go Просмотреть файл

17
 	loggerMock       relay.Logger
17
 	loggerMock       relay.Logger
18
 	ctx              context.Context
18
 	ctx              context.Context
19
 	ctxCancel        context.CancelFunc
19
 	ctxCancel        context.CancelFunc
20
-	telegramConnMock *testlib.NetConnMock
21
-	clientConnMock   *testlib.NetConnMock
20
+	telegramConnMock *testlib.EssentialsConnMock
21
+	clientConnMock   *testlib.EssentialsConnMock
22
 }
22
 }
23
 
23
 
24
 func (suite *RelayTestSuite) SetupTest() {
24
 func (suite *RelayTestSuite) SetupTest() {
26
 	suite.ctx = ctx
26
 	suite.ctx = ctx
27
 	suite.ctxCancel = cancel
27
 	suite.ctxCancel = cancel
28
 	suite.loggerMock = &loggerMock{}
28
 	suite.loggerMock = &loggerMock{}
29
-	suite.telegramConnMock = &testlib.NetConnMock{}
30
-	suite.clientConnMock = &testlib.NetConnMock{}
29
+	suite.telegramConnMock = &testlib.EssentialsConnMock{}
30
+	suite.clientConnMock = &testlib.EssentialsConnMock{}
31
 }
31
 }
32
 
32
 
33
 func (suite *RelayTestSuite) TearDownTest() {
33
 func (suite *RelayTestSuite) TearDownTest() {
38
 
38
 
39
 func (suite *RelayTestSuite) TestExit() {
39
 func (suite *RelayTestSuite) TestExit() {
40
 	suite.telegramConnMock.On("Close").Return(nil)
40
 	suite.telegramConnMock.On("Close").Return(nil)
41
+	suite.telegramConnMock.On("CloseRead").Return(nil).Once()
42
+	suite.telegramConnMock.On("CloseWrite").Return(nil).Once()
41
 	suite.telegramConnMock.On("Read", mock.Anything).Return(10, io.EOF).Once()
43
 	suite.telegramConnMock.On("Read", mock.Anything).Return(10, io.EOF).Once()
42
 	suite.telegramConnMock.On("Write", mock.Anything).Return(10, io.EOF).Maybe()
44
 	suite.telegramConnMock.On("Write", mock.Anything).Return(10, io.EOF).Maybe()
45
+	suite.telegramConnMock.On("SetReadDeadline", mock.Anything).Return(nil).Maybe()
43
 
46
 
44
 	suite.clientConnMock.On("Read", mock.Anything).Return(0, io.EOF).Once()
47
 	suite.clientConnMock.On("Read", mock.Anything).Return(0, io.EOF).Once()
45
 	suite.clientConnMock.On("Write", mock.Anything).Return(10, io.EOF).Maybe()
48
 	suite.clientConnMock.On("Write", mock.Anything).Return(10, io.EOF).Maybe()
46
 	suite.clientConnMock.On("Close").Return(nil)
49
 	suite.clientConnMock.On("Close").Return(nil)
50
+	suite.clientConnMock.On("CloseRead").Return(nil).Once()
51
+	suite.clientConnMock.On("CloseWrite").Return(nil).Once()
52
+	suite.clientConnMock.On("SetReadDeadline", mock.Anything).Return(nil).Maybe()
47
 
53
 
48
 	relay.Relay(suite.ctx, suite.loggerMock, suite.telegramConnMock, suite.clientConnMock)
54
 	relay.Relay(suite.ctx, suite.loggerMock, suite.telegramConnMock, suite.clientConnMock)
49
 }
55
 }

+ 1
- 1
mtglib/internal/relay/sync_pair.go Просмотреть файл

48
 	s.mutex.Lock()
48
 	s.mutex.Lock()
49
 	defer s.mutex.Unlock()
49
 	defer s.mutex.Unlock()
50
 
50
 
51
-	return s.writer.Flush()
51
+	return s.writer.Flush() // nolint: wrapcheck
52
 }
52
 }
53
 
53
 
54
 func (s *syncPair) readBlocking(p []byte, blocking bool) (int, error) {
54
 func (s *syncPair) readBlocking(p []byte, blocking bool) (int, error) {

+ 3
- 2
mtglib/internal/telegram/init.go Просмотреть файл

2
 
2
 
3
 import (
3
 import (
4
 	"context"
4
 	"context"
5
-	"net"
5
+
6
+	"github.com/9seconds/mtg/v2/essentials"
6
 )
7
 )
7
 
8
 
8
 type preferIP uint8
9
 type preferIP uint8
82
 )
83
 )
83
 
84
 
84
 type Dialer interface {
85
 type Dialer interface {
85
-	DialContext(ctx context.Context, network, address string) (net.Conn, error)
86
+	DialContext(ctx context.Context, network, address string) (essentials.Conn, error)
86
 }
87
 }

+ 4
- 3
mtglib/internal/telegram/telegram.go Просмотреть файл

3
 import (
3
 import (
4
 	"context"
4
 	"context"
5
 	"fmt"
5
 	"fmt"
6
-	"net"
7
 	"strings"
6
 	"strings"
7
+
8
+	"github.com/9seconds/mtg/v2/essentials"
8
 )
9
 )
9
 
10
 
10
 type Telegram struct {
11
 type Telegram struct {
13
 	pool     addressPool
14
 	pool     addressPool
14
 }
15
 }
15
 
16
 
16
-func (t Telegram) Dial(ctx context.Context, dc int) (net.Conn, error) {
17
+func (t Telegram) Dial(ctx context.Context, dc int) (essentials.Conn, error) {
17
 	var addresses []tgAddr
18
 	var addresses []tgAddr
18
 
19
 
19
 	switch t.preferIP {
20
 	switch t.preferIP {
28
 	}
29
 	}
29
 
30
 
30
 	var (
31
 	var (
31
-		conn net.Conn
32
+		conn essentials.Conn
32
 		err  error
33
 		err  error
33
 	)
34
 	)
34
 
35
 

+ 3
- 2
mtglib/proxy.go Просмотреть файл

9
 	"sync"
9
 	"sync"
10
 	"time"
10
 	"time"
11
 
11
 
12
+	"github.com/9seconds/mtg/v2/essentials"
12
 	"github.com/9seconds/mtg/v2/mtglib/internal/faketls"
13
 	"github.com/9seconds/mtg/v2/mtglib/internal/faketls"
13
 	"github.com/9seconds/mtg/v2/mtglib/internal/faketls/record"
14
 	"github.com/9seconds/mtg/v2/mtglib/internal/faketls/record"
14
 	"github.com/9seconds/mtg/v2/mtglib/internal/obfuscated2"
15
 	"github.com/9seconds/mtg/v2/mtglib/internal/obfuscated2"
44
 
45
 
45
 // ServeConn serves a connection. We do not check IP blocklist and
46
 // ServeConn serves a connection. We do not check IP blocklist and
46
 // concurrency limit here.
47
 // concurrency limit here.
47
-func (p *Proxy) ServeConn(conn net.Conn) {
48
+func (p *Proxy) ServeConn(conn essentials.Conn) {
48
 	p.streamWaitGroup.Add(1)
49
 	p.streamWaitGroup.Add(1)
49
 	defer p.streamWaitGroup.Done()
50
 	defer p.streamWaitGroup.Done()
50
 
51
 
299
 
300
 
300
 	pool, err := ants.NewPoolWithFunc(opts.getConcurrency(),
301
 	pool, err := ants.NewPoolWithFunc(opts.getConcurrency(),
301
 		func(arg interface{}) {
302
 		func(arg interface{}) {
302
-			proxy.ServeConn(arg.(net.Conn))
303
+			proxy.ServeConn(arg.(essentials.Conn))
303
 		},
304
 		},
304
 		ants.WithLogger(opts.getLogger("ants")),
305
 		ants.WithLogger(opts.getLogger("ants")),
305
 		ants.WithNonblocking(true))
306
 		ants.WithNonblocking(true))

+ 5
- 3
mtglib/stream_context.go Просмотреть файл

6
 	"encoding/base64"
6
 	"encoding/base64"
7
 	"net"
7
 	"net"
8
 	"time"
8
 	"time"
9
+
10
+	"github.com/9seconds/mtg/v2/essentials"
9
 )
11
 )
10
 
12
 
11
 type streamContext struct {
13
 type streamContext struct {
12
 	ctx          context.Context
14
 	ctx          context.Context
13
 	ctxCancel    context.CancelFunc
15
 	ctxCancel    context.CancelFunc
14
-	clientConn   net.Conn
15
-	telegramConn net.Conn
16
+	clientConn   essentials.Conn
17
+	telegramConn essentials.Conn
16
 	streamID     string
18
 	streamID     string
17
 	dc           int
19
 	dc           int
18
 	logger       Logger
20
 	logger       Logger
50
 	return s.clientConn.RemoteAddr().(*net.TCPAddr).IP
52
 	return s.clientConn.RemoteAddr().(*net.TCPAddr).IP
51
 }
53
 }
52
 
54
 
53
-func newStreamContext(ctx context.Context, logger Logger, clientConn net.Conn) *streamContext {
55
+func newStreamContext(ctx context.Context, logger Logger, clientConn essentials.Conn) *streamContext {
54
 	connIDBytes := make([]byte, ConnectionIDBytesLength)
56
 	connIDBytes := make([]byte, ConnectionIDBytesLength)
55
 
57
 
56
 	if _, err := rand.Read(connIDBytes); err != nil {
58
 	if _, err := rand.Read(connIDBytes); err != nil {

+ 3
- 3
mtglib/stream_context_internal_test.go Просмотреть файл

12
 type StreamContextTestSuite struct {
12
 type StreamContextTestSuite struct {
13
 	suite.Suite
13
 	suite.Suite
14
 
14
 
15
-	connMock  *testlib.NetConnMock
15
+	connMock  *testlib.EssentialsConnMock
16
 	logger    NoopLogger
16
 	logger    NoopLogger
17
 	ctx       *streamContext
17
 	ctx       *streamContext
18
 	ctxCancel context.CancelFunc
18
 	ctxCancel context.CancelFunc
27
 	ctx = context.WithValue(ctx, "key", "value") // nolint: golint, revive, staticcheck
27
 	ctx = context.WithValue(ctx, "key", "value") // nolint: golint, revive, staticcheck
28
 
28
 
29
 	suite.ctxCancel = cancel
29
 	suite.ctxCancel = cancel
30
-	suite.connMock = &testlib.NetConnMock{}
30
+	suite.connMock = &testlib.EssentialsConnMock{}
31
 
31
 
32
 	addr := &net.TCPAddr{
32
 	addr := &net.TCPAddr{
33
 		IP:   net.ParseIP("10.0.0.10"),
33
 		IP:   net.ParseIP("10.0.0.10"),
73
 func (suite *StreamContextTestSuite) TestClose() {
73
 func (suite *StreamContextTestSuite) TestClose() {
74
 	suite.connMock.On("Close").Once().Return(nil)
74
 	suite.connMock.On("Close").Once().Return(nil)
75
 
75
 
76
-	tgConnMock := &testlib.NetConnMock{}
76
+	tgConnMock := &testlib.EssentialsConnMock{}
77
 	tgConnMock.On("Close").Once().Return(nil)
77
 	tgConnMock.On("Close").Once().Return(nil)
78
 
78
 
79
 	suite.ctx.telegramConn = tgConnMock
79
 	suite.ctx.telegramConn = tgConnMock

+ 7
- 5
network/circuit_breaker.go Просмотреть файл

2
 
2
 
3
 import (
3
 import (
4
 	"context"
4
 	"context"
5
-	"net"
6
 	"sync/atomic"
5
 	"sync/atomic"
7
 	"time"
6
 	"time"
7
+
8
+	"github.com/9seconds/mtg/v2/essentials"
8
 )
9
 )
9
 
10
 
10
 const (
11
 const (
30
 	resetFailuresTimeout time.Duration
31
 	resetFailuresTimeout time.Duration
31
 }
32
 }
32
 
33
 
33
-func (c *circuitBreakerDialer) Dial(network, address string) (net.Conn, error) {
34
+func (c *circuitBreakerDialer) Dial(network, address string) (essentials.Conn, error) {
34
 	return c.DialContext(context.Background(), network, address)
35
 	return c.DialContext(context.Background(), network, address)
35
 }
36
 }
36
 
37
 
37
 func (c *circuitBreakerDialer) DialContext(ctx context.Context,
38
 func (c *circuitBreakerDialer) DialContext(ctx context.Context,
38
-	network, address string) (net.Conn, error) {
39
+	network, address string) (essentials.Conn, error) {
39
 	switch atomic.LoadUint32(&c.state) {
40
 	switch atomic.LoadUint32(&c.state) {
40
 	case circuitBreakerStateClosed:
41
 	case circuitBreakerStateClosed:
41
 		return c.doClosed(ctx, network, address)
42
 		return c.doClosed(ctx, network, address)
47
 }
48
 }
48
 
49
 
49
 func (c *circuitBreakerDialer) doClosed(ctx context.Context,
50
 func (c *circuitBreakerDialer) doClosed(ctx context.Context,
50
-	network, address string) (net.Conn, error) {
51
+	network, address string) (essentials.Conn, error) {
51
 	conn, err := c.Dialer.DialContext(ctx, network, address)
52
 	conn, err := c.Dialer.DialContext(ctx, network, address)
52
 
53
 
53
 	select {
54
 	select {
78
 	return conn, err // nolint: wrapcheck
79
 	return conn, err // nolint: wrapcheck
79
 }
80
 }
80
 
81
 
81
-func (c *circuitBreakerDialer) doHalfOpened(ctx context.Context, network, address string) (net.Conn, error) {
82
+func (c *circuitBreakerDialer) doHalfOpened(ctx context.Context,
83
+	network, address string) (essentials.Conn, error) {
82
 	if !atomic.CompareAndSwapUint32(&c.halfOpenAttempts, 0, 1) {
84
 	if !atomic.CompareAndSwapUint32(&c.halfOpenAttempts, 0, 1) {
83
 		return nil, ErrCircuitBreakerOpened
85
 		return nil, ErrCircuitBreakerOpened
84
 	}
86
 	}

+ 2
- 2
network/circuit_breaker_internal_test.go Просмотреть файл

21
 	mutex          sync.Mutex
21
 	mutex          sync.Mutex
22
 	ctx            context.Context
22
 	ctx            context.Context
23
 	ctxCancel      context.CancelFunc
23
 	ctxCancel      context.CancelFunc
24
-	connMock       *testlib.NetConnMock
24
+	connMock       *testlib.EssentialsConnMock
25
 	baseDialerMock *DialerMock
25
 	baseDialerMock *DialerMock
26
 }
26
 }
27
 
27
 
29
 	suite.mutex = sync.Mutex{}
29
 	suite.mutex = sync.Mutex{}
30
 	suite.ctx, suite.ctxCancel = context.WithCancel(context.Background())
30
 	suite.ctx, suite.ctxCancel = context.WithCancel(context.Background())
31
 	suite.baseDialerMock = &DialerMock{}
31
 	suite.baseDialerMock = &DialerMock{}
32
-	suite.connMock = &testlib.NetConnMock{}
32
+	suite.connMock = &testlib.EssentialsConnMock{}
33
 	suite.d = newCircuitBreakerDialer(suite.baseDialerMock,
33
 	suite.d = newCircuitBreakerDialer(suite.baseDialerMock,
34
 		3, 100*time.Millisecond, 50*time.Millisecond)
34
 		3, 100*time.Millisecond, 50*time.Millisecond)
35
 }
35
 }

+ 5
- 3
network/default.go Просмотреть файл

5
 	"fmt"
5
 	"fmt"
6
 	"net"
6
 	"net"
7
 	"time"
7
 	"time"
8
+
9
+	"github.com/9seconds/mtg/v2/essentials"
8
 )
10
 )
9
 
11
 
10
 type defaultDialer struct {
12
 type defaultDialer struct {
11
 	net.Dialer
13
 	net.Dialer
12
 }
14
 }
13
 
15
 
14
-func (d *defaultDialer) Dial(network, address string) (net.Conn, error) {
16
+func (d *defaultDialer) Dial(network, address string) (essentials.Conn, error) {
15
 	return d.DialContext(context.Background(), network, address)
17
 	return d.DialContext(context.Background(), network, address)
16
 }
18
 }
17
 
19
 
18
-func (d *defaultDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
20
+func (d *defaultDialer) DialContext(ctx context.Context, network, address string) (essentials.Conn, error) {
19
 	switch network {
21
 	switch network {
20
 	case "tcp", "tcp4", "tcp6": // nolint: goconst
22
 	case "tcp", "tcp4", "tcp6": // nolint: goconst
21
 	default:
23
 	default:
34
 		return nil, fmt.Errorf("cannot set socket options: %w", err)
36
 		return nil, fmt.Errorf("cannot set socket options: %w", err)
35
 	}
37
 	}
36
 
38
 
37
-	return conn, nil
39
+	return conn.(essentials.Conn), nil
38
 }
40
 }
39
 
41
 
40
 // NewDefaultDialer build a new dialer which dials bypassing proxies
42
 // NewDefaultDialer build a new dialer which dials bypassing proxies

+ 4
- 3
network/init.go Просмотреть файл

20
 import (
20
 import (
21
 	"context"
21
 	"context"
22
 	"errors"
22
 	"errors"
23
-	"net"
24
 	"time"
23
 	"time"
24
+
25
+	"github.com/9seconds/mtg/v2/essentials"
25
 )
26
 )
26
 
27
 
27
 const (
28
 const (
95
 // Dialer defines an interface which is required to bootstrap a network
96
 // Dialer defines an interface which is required to bootstrap a network
96
 // instance from.
97
 // instance from.
97
 type Dialer interface {
98
 type Dialer interface {
98
-	Dial(network, address string) (net.Conn, error)
99
-	DialContext(ctx context.Context, network, address string) (net.Conn, error)
99
+	Dial(network, address string) (essentials.Conn, error)
100
+	DialContext(ctx context.Context, network, address string) (essentials.Conn, error)
100
 }
101
 }

+ 5
- 5
network/init_internal_test.go Просмотреть файл

2
 
2
 
3
 import (
3
 import (
4
 	"context"
4
 	"context"
5
-	"net"
6
 
5
 
6
+	"github.com/9seconds/mtg/v2/essentials"
7
 	"github.com/stretchr/testify/mock"
7
 	"github.com/stretchr/testify/mock"
8
 )
8
 )
9
 
9
 
11
 	mock.Mock
11
 	mock.Mock
12
 }
12
 }
13
 
13
 
14
-func (d *DialerMock) Dial(network, address string) (net.Conn, error) {
14
+func (d *DialerMock) Dial(network, address string) (essentials.Conn, error) {
15
 	args := d.Called(network, address)
15
 	args := d.Called(network, address)
16
 
16
 
17
-	return args.Get(0).(net.Conn), args.Error(1) // nolint: wrapcheck
17
+	return args.Get(0).(essentials.Conn), args.Error(1) // nolint: wrapcheck
18
 }
18
 }
19
 
19
 
20
-func (d *DialerMock) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
20
+func (d *DialerMock) DialContext(ctx context.Context, network, address string) (essentials.Conn, error) {
21
 	args := d.Called(ctx, network, address)
21
 	args := d.Called(ctx, network, address)
22
 
22
 
23
-	return args.Get(0).(net.Conn), args.Error(1) // nolint: wrapcheck
23
+	return args.Get(0).(essentials.Conn), args.Error(1) // nolint: wrapcheck
24
 }
24
 }

+ 8
- 5
network/init_test.go Просмотреть файл

8
 	"net/url"
8
 	"net/url"
9
 	"strings"
9
 	"strings"
10
 
10
 
11
+	"github.com/9seconds/mtg/v2/essentials"
11
 	"github.com/9seconds/mtg/v2/network"
12
 	"github.com/9seconds/mtg/v2/network"
12
 	socks5 "github.com/armon/go-socks5"
13
 	socks5 "github.com/armon/go-socks5"
13
 	"github.com/mccutchen/go-httpbin/httpbin"
14
 	"github.com/mccutchen/go-httpbin/httpbin"
18
 	mock.Mock
19
 	mock.Mock
19
 }
20
 }
20
 
21
 
21
-func (d *DialerMock) Dial(network, address string) (net.Conn, error) {
22
+func (d *DialerMock) Dial(network, address string) (essentials.Conn, error) {
22
 	args := d.Called(network, address)
23
 	args := d.Called(network, address)
23
 
24
 
24
-	return args.Get(0).(net.Conn), args.Error(1) // nolint: wrapcheck
25
+	return args.Get(0).(essentials.Conn), args.Error(1) // nolint: wrapcheck
25
 }
26
 }
26
 
27
 
27
-func (d *DialerMock) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
28
+func (d *DialerMock) DialContext(ctx context.Context, network, address string) (essentials.Conn, error) {
28
 	args := d.Called(ctx, network, address)
29
 	args := d.Called(ctx, network, address)
29
 
30
 
30
-	return args.Get(0).(net.Conn), args.Error(1) // nolint: wrapcheck
31
+	return args.Get(0).(essentials.Conn), args.Error(1) // nolint: wrapcheck
31
 }
32
 }
32
 
33
 
33
 type HTTPServerTestSuite struct {
34
 type HTTPServerTestSuite struct {
53
 func (suite *HTTPServerTestSuite) MakeHTTPClient(dialer network.Dialer) *http.Client {
54
 func (suite *HTTPServerTestSuite) MakeHTTPClient(dialer network.Dialer) *http.Client {
54
 	return &http.Client{
55
 	return &http.Client{
55
 		Transport: &http.Transport{
56
 		Transport: &http.Transport{
56
-			DialContext: dialer.DialContext,
57
+			DialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
58
+				return dialer.DialContext(ctx, network, address) // nolint: wrapcheck
59
+			},
57
 		},
60
 		},
58
 	}
61
 	}
59
 }
62
 }

+ 4
- 3
network/load_balanced_socks5.go Просмотреть файл

4
 	"context"
4
 	"context"
5
 	"fmt"
5
 	"fmt"
6
 	"math/rand"
6
 	"math/rand"
7
-	"net"
8
 	"net/url"
7
 	"net/url"
8
+
9
+	"github.com/9seconds/mtg/v2/essentials"
9
 )
10
 )
10
 
11
 
11
 type loadBalancedSocks5Dialer struct {
12
 type loadBalancedSocks5Dialer struct {
12
 	dialers []Dialer
13
 	dialers []Dialer
13
 }
14
 }
14
 
15
 
15
-func (l loadBalancedSocks5Dialer) Dial(network, address string) (net.Conn, error) {
16
+func (l loadBalancedSocks5Dialer) Dial(network, address string) (essentials.Conn, error) {
16
 	return l.DialContext(context.Background(), network, address)
17
 	return l.DialContext(context.Background(), network, address)
17
 }
18
 }
18
 
19
 
19
-func (l loadBalancedSocks5Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
20
+func (l loadBalancedSocks5Dialer) DialContext(ctx context.Context, network, address string) (essentials.Conn, error) {
20
 	length := len(l.dialers)
21
 	length := len(l.dialers)
21
 	start := rand.Intn(length)
22
 	start := rand.Intn(length)
22
 	moved := false
23
 	moved := false

+ 10
- 6
network/network.go Просмотреть файл

9
 	"sync"
9
 	"sync"
10
 	"time"
10
 	"time"
11
 
11
 
12
+	"github.com/9seconds/mtg/v2/essentials"
12
 	"github.com/9seconds/mtg/v2/mtglib"
13
 	"github.com/9seconds/mtg/v2/mtglib"
13
 )
14
 )
14
 
15
 
30
 	dns         *dnsResolver
31
 	dns         *dnsResolver
31
 }
32
 }
32
 
33
 
33
-func (n *network) Dial(protocol, address string) (net.Conn, error) {
34
+func (n *network) Dial(protocol, address string) (essentials.Conn, error) {
34
 	return n.DialContext(context.Background(), protocol, address)
35
 	return n.DialContext(context.Background(), protocol, address)
35
 }
36
 }
36
 
37
 
37
-func (n *network) DialContext(ctx context.Context, protocol, address string) (net.Conn, error) {
38
+func (n *network) DialContext(ctx context.Context, protocol, address string) (essentials.Conn, error) {
38
 	host, port, _ := net.SplitHostPort(address)
39
 	host, port, _ := net.SplitHostPort(address)
39
 
40
 
40
 	ips, err := n.dnsResolve(protocol, host)
41
 	ips, err := n.dnsResolve(protocol, host)
46
 		ips[i], ips[j] = ips[j], ips[i]
47
 		ips[i], ips[j] = ips[j], ips[i]
47
 	})
48
 	})
48
 
49
 
49
-	var conn net.Conn
50
+	var conn essentials.Conn
51
+
50
 	for _, v := range ips {
52
 	for _, v := range ips {
51
 		conn, err = n.dialer.DialContext(ctx, protocol, net.JoinHostPort(v, port))
53
 		conn, err = n.dialer.DialContext(ctx, protocol, net.JoinHostPort(v, port))
52
 
54
 
59
 }
61
 }
60
 
62
 
61
 func (n *network) MakeHTTPClient(dialFunc func(ctx context.Context,
63
 func (n *network) MakeHTTPClient(dialFunc func(ctx context.Context,
62
-	network, address string) (net.Conn, error)) *http.Client {
64
+	network, address string) (essentials.Conn, error)) *http.Client {
63
 	if dialFunc == nil {
65
 	if dialFunc == nil {
64
 		dialFunc = n.DialContext
66
 		dialFunc = n.DialContext
65
 	}
67
 	}
144
 
146
 
145
 func makeHTTPClient(userAgent string,
147
 func makeHTTPClient(userAgent string,
146
 	timeout time.Duration,
148
 	timeout time.Duration,
147
-	dialFunc func(ctx context.Context, network, address string) (net.Conn, error)) *http.Client {
149
+	dialFunc func(ctx context.Context, network, address string) (essentials.Conn, error)) *http.Client {
148
 	return &http.Client{
150
 	return &http.Client{
149
 		Timeout: timeout,
151
 		Timeout: timeout,
150
 		Transport: networkHTTPTransport{
152
 		Transport: networkHTTPTransport{
151
 			userAgent: userAgent,
153
 			userAgent: userAgent,
152
 			next: &http.Transport{
154
 			next: &http.Transport{
153
-				DialContext: dialFunc,
155
+				DialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
156
+					return dialFunc(ctx, network, address)
157
+				},
154
 			},
158
 			},
155
 		},
159
 		},
156
 	}
160
 	}

+ 146
- 5
network/socks5.go Просмотреть файл

1
 package network
1
 package network
2
 
2
 
3
 import (
3
 import (
4
+	"context"
4
 	"fmt"
5
 	"fmt"
6
+	"io"
7
+	"net"
5
 	"net/url"
8
 	"net/url"
6
 
9
 
7
-	"golang.org/x/net/proxy"
10
+	"github.com/9seconds/mtg/v2/essentials"
11
+	"github.com/txthinking/socks5"
8
 )
12
 )
9
 
13
 
14
+type socks5Dialer struct {
15
+	Dialer
16
+
17
+	username     []byte
18
+	password     []byte
19
+	proxyAddress string
20
+}
21
+
22
+func (s socks5Dialer) Dial(network, address string) (essentials.Conn, error) {
23
+	return s.DialContext(context.Background(), network, address)
24
+}
25
+
26
+func (s socks5Dialer) DialContext(ctx context.Context, network, address string) (essentials.Conn, error) {
27
+	switch network {
28
+	case "tcp", "tcp4", "tcp6":
29
+	default:
30
+		return nil, fmt.Errorf("%s network type is not supported", network)
31
+	}
32
+
33
+	conn, err := s.Dialer.DialContext(ctx, network, s.proxyAddress)
34
+	if err != nil {
35
+		return nil, fmt.Errorf("cannot dial to the proxy: %w", err)
36
+	}
37
+
38
+	if err := s.handshake(conn); err != nil {
39
+		conn.Close()
40
+
41
+		return nil, fmt.Errorf("cannot perform a handshake: %w", err)
42
+	}
43
+
44
+	if err := s.connect(conn, address); err != nil {
45
+		conn.Close()
46
+
47
+		return nil, fmt.Errorf("cannot connect to a destination host %s: %w", address, err)
48
+	}
49
+
50
+	return conn, nil
51
+}
52
+
53
+func (s socks5Dialer) handshake(conn io.ReadWriter) error {
54
+	authMethod := socks5.MethodUsernamePassword
55
+	if len(s.username)+len(s.password) == 0 {
56
+		authMethod = socks5.MethodNone
57
+	}
58
+
59
+	if err := s.handshakeNegotiation(conn, authMethod); err != nil {
60
+		return fmt.Errorf("cannot perform negotiation: %w", err)
61
+	}
62
+
63
+	if authMethod == socks5.MethodNone {
64
+		return nil
65
+	}
66
+
67
+	if err := s.handshakeAuth(conn); err != nil {
68
+		return fmt.Errorf("cannot authenticate: %w", err)
69
+	}
70
+
71
+	return nil
72
+}
73
+
74
+func (s socks5Dialer) handshakeNegotiation(conn io.ReadWriter, authMethod byte) error {
75
+	request := socks5.NewNegotiationRequest([]byte{authMethod})
76
+	if _, err := request.WriteTo(conn); err != nil {
77
+		return fmt.Errorf("cannot send request: %w", err)
78
+	}
79
+
80
+	response, err := socks5.NewNegotiationReplyFrom(conn)
81
+	if err != nil {
82
+		return fmt.Errorf("cannot read response: %w", err)
83
+	}
84
+
85
+	if response.Method != authMethod {
86
+		return fmt.Errorf("%v is unsupported auth method", authMethod)
87
+	}
88
+
89
+	return nil
90
+}
91
+
92
+func (s socks5Dialer) handshakeAuth(conn io.ReadWriter) error {
93
+	request := socks5.NewUserPassNegotiationRequest(s.username, s.password)
94
+
95
+	if _, err := request.WriteTo(conn); err != nil {
96
+		return fmt.Errorf("cannot send a request: %w", err)
97
+	}
98
+
99
+	response, err := socks5.NewUserPassNegotiationReplyFrom(conn)
100
+	if err != nil {
101
+		return fmt.Errorf("cannot read a response: %w", err)
102
+	}
103
+
104
+	if response.Status != socks5.UserPassStatusSuccess {
105
+		return fmt.Errorf("authenticate has failed: %v", response.Status)
106
+	}
107
+
108
+	return nil
109
+}
110
+
111
+func (s socks5Dialer) connect(conn io.ReadWriter, address string) error {
112
+	addrType, host, port, err := socks5.ParseAddress(address)
113
+	if err != nil {
114
+		return fmt.Errorf("cannot parse address: %w", err)
115
+	}
116
+
117
+	if addrType == socks5.ATYPDomain {
118
+		host = host[1:]
119
+	}
120
+
121
+	request := socks5.NewRequest(socks5.CmdConnect, addrType, host, port)
122
+
123
+	if _, err := request.WriteTo(conn); err != nil {
124
+		return fmt.Errorf("cannot send a request: %w", err)
125
+	}
126
+
127
+	response, err := socks5.NewReplyFrom(conn)
128
+	if err != nil {
129
+		return fmt.Errorf("cannot read a response: %w", err)
130
+	}
131
+
132
+	if response.Rep != socks5.RepSuccess {
133
+		return fmt.Errorf("unsuccessful request: %v", response.Rep)
134
+	}
135
+
136
+	return nil
137
+}
138
+
10
 // NewSocks5Dialer build a new dialer from a given one (so, in theory
139
 // NewSocks5Dialer build a new dialer from a given one (so, in theory
11
 // you can chain here). Proxy parameters are passed with URI in a form of:
140
 // you can chain here). Proxy parameters are passed with URI in a form of:
12
 //
141
 //
13
 //     socks5://[user:[password]]@host:port
142
 //     socks5://[user:[password]]@host:port
14
 func NewSocks5Dialer(baseDialer Dialer, proxyURL *url.URL) (Dialer, error) {
143
 func NewSocks5Dialer(baseDialer Dialer, proxyURL *url.URL) (Dialer, error) {
15
-	rv, err := proxy.FromURL(proxyURL, baseDialer)
16
-	if err != nil {
17
-		return nil, fmt.Errorf("cannot initialize socks5 proxy dialer: %w", err)
144
+	if _, _, err := net.SplitHostPort(proxyURL.Host); err != nil {
145
+		return nil, fmt.Errorf("incorrect url %s", proxyURL.Redacted())
146
+	}
147
+
148
+	dialer := socks5Dialer{
149
+		Dialer:       baseDialer,
150
+		proxyAddress: proxyURL.Host,
151
+	}
152
+
153
+	if proxyURL.User != nil {
154
+		password, isSet := proxyURL.User.Password()
155
+		if isSet {
156
+			dialer.username = []byte(proxyURL.User.Username())
157
+			dialer.password = []byte(password)
158
+		}
18
 	}
159
 	}
19
 
160
 
20
-	return rv.(Dialer), nil
161
+	return dialer, nil
21
 }
162
 }

+ 1
- 1
network/socks5_test.go Просмотреть файл

55
 	suite.Equal(http.StatusOK, resp.StatusCode)
55
 	suite.Equal(http.StatusOK, resp.StatusCode)
56
 }
56
 }
57
 
57
 
58
-func TestSocks5TestSuite(t *testing.T) {
58
+func TestSocks5(t *testing.T) {
59
 	t.Parallel()
59
 	t.Parallel()
60
 	suite.Run(t, &Socks5TestSuite{})
60
 	suite.Run(t, &Socks5TestSuite{})
61
 }
61
 }

Загрузка…
Отмена
Сохранить