Bläddra i källkod

Add v2 network package

tags/v2.1.13
9seconds 2 månader sedan
förälder
incheckning
42927c8bdc

+ 25
- 0
essentials/conns.go Visa fil

@@ -24,3 +24,28 @@ type Conn interface {
24 24
 	CloseableReader
25 25
 	CloseableWriter
26 26
 }
27
+
28
+type netConnWrapper struct {
29
+	net.Conn
30
+}
31
+
32
+func (n netConnWrapper) CloseRead() error {
33
+	if conn, ok := n.Conn.(CloseableReader); ok {
34
+		return conn.CloseRead()
35
+	}
36
+
37
+	return n.Close()
38
+}
39
+
40
+func (n netConnWrapper) CloseWrite() error {
41
+	if conn, ok := n.Conn.(CloseableWriter); ok {
42
+		return conn.CloseWrite()
43
+	}
44
+
45
+	return n.Close()
46
+}
47
+
48
+// WrapConn wraps a generic [net.Conn] into Conn.
49
+func WrapNetConn(conn net.Conn) Conn {
50
+	return netConnWrapper{conn}
51
+}

+ 1
- 0
go.mod Visa fil

@@ -46,6 +46,7 @@ require (
46 46
 	github.com/pmezard/go-difflib v1.0.0 // indirect
47 47
 	github.com/prometheus/client_model v0.6.2 // indirect
48 48
 	github.com/rogpeppe/go-internal v1.14.1 // indirect
49
+	github.com/things-go/go-socks5 v0.1.0 // indirect
49 50
 	github.com/txthinking/runnergroup v0.0.0-20250224021307-5864ffeb65ae // indirect
50 51
 	go.yaml.in/yaml/v2 v2.4.3 // indirect
51 52
 	golang.org/x/sync v0.19.0 // indirect

+ 2
- 0
go.sum Visa fil

@@ -89,6 +89,8 @@ github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXl
89 89
 github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
90 90
 github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
91 91
 github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
92
+github.com/things-go/go-socks5 v0.1.0 h1:4f5dz0iMQ6cA4wseFmyLmCHmg3SWJTW92ndrKS6oERg=
93
+github.com/things-go/go-socks5 v0.1.0/go.mod h1:Riabiyu52kLsla0YmJqunt1c1JEl6iXSr4bRd7swFEA=
92 94
 github.com/txthinking/runnergroup v0.0.0-20210608031112-152c7c4432bf/go.mod h1:CLUSJbazqETbaR+i0YAhXBICV9TrKH93pziccMhmhpM=
93 95
 github.com/txthinking/runnergroup v0.0.0-20250224021307-5864ffeb65ae h1:ArVM1jICfm7g4E4dBet+KHUFMLuxmj1Nxdp/tr3ByCU=
94 96
 github.com/txthinking/runnergroup v0.0.0-20250224021307-5864ffeb65ae/go.mod h1:cldYm15/XHcGt7ndItnEWHwFZo7dinU+2QoyjfErhsI=

+ 49
- 0
network/v2/base_http_test.go Visa fil

@@ -0,0 +1,49 @@
1
+package network_test
2
+
3
+import (
4
+	"io"
5
+	"net/http"
6
+	"net/http/httptest"
7
+	"testing"
8
+
9
+	"github.com/9seconds/mtg/v2/network/v2"
10
+	"github.com/stretchr/testify/suite"
11
+)
12
+
13
+type BaseHTTPTestSuite struct {
14
+	suite.Suite
15
+
16
+	http   *httptest.Server
17
+	client *http.Client
18
+}
19
+
20
+func (suite *BaseHTTPTestSuite) SetupSuite() {
21
+	suite.http = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
22
+		w.WriteHeader(http.StatusOK)
23
+		w.Write([]byte(r.Header.Get("User-Agent")))
24
+	}))
25
+}
26
+
27
+func (suite *BaseHTTPTestSuite) SetupTest() {
28
+	suite.client = network.New(nil, "mtg/1", 0, 0, 0).MakeHTTPClient(nil)
29
+}
30
+
31
+func (suite *BaseHTTPTestSuite) TestGet() {
32
+	resp, err := suite.client.Get(suite.http.URL)
33
+	suite.NoError(err)
34
+
35
+	defer resp.Body.Close()
36
+
37
+	data, err := io.ReadAll(resp.Body)
38
+	suite.NoError(err)
39
+	suite.Equal("mtg/1", string(data))
40
+}
41
+
42
+func (suite *BaseHTTPTestSuite) TearDownSuite() {
43
+	suite.http.Close()
44
+}
45
+
46
+func TestBaseHTTP(t *testing.T) {
47
+	t.Parallel()
48
+	suite.Run(t, &BaseHTTPTestSuite{})
49
+}

+ 83
- 0
network/v2/base_network_test.go Visa fil

@@ -0,0 +1,83 @@
1
+package network_test
2
+
3
+import (
4
+	"context"
5
+	"testing"
6
+
7
+	"github.com/9seconds/mtg/v2/network/v2"
8
+	"github.com/stretchr/testify/assert"
9
+	"github.com/stretchr/testify/suite"
10
+)
11
+
12
+type BaseNetworkTestSuite struct {
13
+	EchoServerTestSuite
14
+
15
+	net network.Network
16
+}
17
+
18
+func (suite *BaseNetworkTestSuite) SetupSuite() {
19
+	suite.EchoServerTestSuite.SetupSuite()
20
+
21
+	suite.net = network.New(nil, "agent", 0, 0, 0)
22
+}
23
+
24
+func (suite *BaseNetworkTestSuite) TestDialUnknownNetwork() {
25
+	testData := []string{
26
+		"udp",
27
+		"udp4",
28
+		"udp6",
29
+		"unix",
30
+	}
31
+
32
+	for _, name := range testData {
33
+		suite.T().Run(name, func(t *testing.T) {
34
+			_, err := suite.net.Dial(name, suite.EchoServerAddr())
35
+			assert.Error(t, err)
36
+		})
37
+	}
38
+}
39
+
40
+func (suite *BaseNetworkTestSuite) TestDial() {
41
+	conn, err := suite.net.Dial("tcp4", suite.EchoServerAddr())
42
+	suite.NoError(err)
43
+
44
+	buf := []byte{1, 2, 3, 4, 5}
45
+	n, err := conn.Write(buf)
46
+	suite.Equal(5, n)
47
+	suite.NoError(err)
48
+
49
+	another := make([]byte, len(buf))
50
+	n, err = conn.Read(another)
51
+	suite.NoError(err)
52
+	suite.Equal(len(another), n)
53
+	suite.Equal(buf, another)
54
+}
55
+
56
+func (suite *BaseNetworkTestSuite) TestDialContextOk() {
57
+	conn, err := suite.net.DialContext(context.Background(), "tcp4", suite.EchoServerAddr())
58
+	suite.NoError(err)
59
+
60
+	buf := []byte{1, 2, 3, 4, 5}
61
+	n, err := conn.Write(buf)
62
+	suite.Equal(5, n)
63
+	suite.NoError(err)
64
+
65
+	another := make([]byte, len(buf))
66
+	n, err = conn.Read(another)
67
+	suite.NoError(err)
68
+	suite.Equal(len(another), n)
69
+	suite.Equal(buf, another)
70
+}
71
+
72
+func (suite *BaseNetworkTestSuite) TestDialContextClosed() {
73
+	ctx, cancel := context.WithCancel(context.Background())
74
+	cancel()
75
+
76
+	_, err := suite.net.DialContext(ctx, "tcp4", suite.EchoServerAddr())
77
+	suite.ErrorIs(err, ctx.Err())
78
+}
79
+
80
+func TestNetworkBase(t *testing.T) {
81
+	t.Parallel()
82
+	suite.Run(t, &BaseNetworkTestSuite{})
83
+}

+ 106
- 0
network/v2/echo_server_test.go Visa fil

@@ -0,0 +1,106 @@
1
+package network_test
2
+
3
+import (
4
+	"context"
5
+	"io"
6
+	"net"
7
+	"sync"
8
+
9
+	"github.com/stretchr/testify/require"
10
+	"github.com/stretchr/testify/suite"
11
+)
12
+
13
+type EchoServer struct {
14
+	wg        sync.WaitGroup
15
+	ctx       context.Context
16
+	ctxCancel context.CancelFunc
17
+	listener  net.Listener
18
+}
19
+
20
+func (e *EchoServer) Run() {
21
+	e.wg.Go(func() {
22
+		<-e.ctx.Done()
23
+		e.listener.Close()
24
+	})
25
+
26
+	e.wg.Go(func() {
27
+		for {
28
+			conn, err := e.listener.Accept()
29
+			if err != nil {
30
+				return
31
+			}
32
+
33
+			e.wg.Go(func() {
34
+				<-e.ctx.Done()
35
+				conn.Close()
36
+			})
37
+			e.wg.Go(func() {
38
+				e.process(conn)
39
+			})
40
+		}
41
+	})
42
+}
43
+
44
+func (e *EchoServer) Stop() {
45
+	e.ctxCancel()
46
+	e.wg.Wait()
47
+}
48
+
49
+func (e *EchoServer) Addr() string {
50
+	return e.listener.Addr().String()
51
+}
52
+
53
+func (e *EchoServer) process(conn io.ReadWriter) {
54
+	buf := [4096]byte{}
55
+
56
+	for {
57
+		select {
58
+		case <-e.ctx.Done():
59
+			return
60
+		default:
61
+		}
62
+
63
+		n, err := conn.Read(buf[:])
64
+		if err != nil {
65
+			return
66
+		}
67
+
68
+		select {
69
+		case <-e.ctx.Done():
70
+			return
71
+		default:
72
+		}
73
+
74
+		if _, err = conn.Write(buf[:n]); err != nil {
75
+			return
76
+		}
77
+	}
78
+}
79
+
80
+type EchoServerTestSuite struct {
81
+	suite.Suite
82
+
83
+	echoServer *EchoServer
84
+}
85
+
86
+func (suite *EchoServerTestSuite) SetupSuite() {
87
+	ctx, cancel := context.WithCancel(context.Background())
88
+
89
+	listener, err := net.Listen("tcp", "127.0.0.1:0")
90
+	require.NoError(suite.T(), err)
91
+
92
+	suite.echoServer = &EchoServer{
93
+		ctx:       ctx,
94
+		ctxCancel: cancel,
95
+		listener:  listener,
96
+	}
97
+	suite.echoServer.Run()
98
+}
99
+
100
+func (suite *EchoServerTestSuite) TearDownSuite() {
101
+	suite.echoServer.Stop()
102
+}
103
+
104
+func (suite *EchoServerTestSuite) EchoServerAddr() string {
105
+	return suite.echoServer.Addr()
106
+}

+ 14
- 0
network/v2/http.go Visa fil

@@ -0,0 +1,14 @@
1
+package network
2
+
3
+import "net/http"
4
+
5
+type networkHTTPTransport struct {
6
+	userAgent string
7
+	next      http.RoundTripper
8
+}
9
+
10
+func (n networkHTTPTransport) RoundTrip(req *http.Request) (*http.Response, error) {
11
+	req.Header.Set("User-Agent", n.userAgent)
12
+
13
+	return n.next.RoundTrip(req) //nolint: wrapcheck
14
+}

+ 45
- 0
network/v2/init.go Visa fil

@@ -0,0 +1,45 @@
1
+// Network contains a default implementation of the network.
2
+//
3
+// Please see [mtglib.Network] interface to get some basic idea behind this
4
+// abstraction.
5
+//
6
+// This implementation is more simple that v1 because life shows that all
7
+// this complexity, especially around circuit breakers and DoH is not really
8
+// required. There is no chance that if DNS address is spoofed, that real
9
+// IP would work as expected.
10
+package network
11
+
12
+import (
13
+	"errors"
14
+	"net"
15
+	"time"
16
+
17
+	"github.com/9seconds/mtg/v2/mtglib"
18
+)
19
+
20
+const (
21
+	// DefaultTimeout is a default timeout for establishing TCP connection.
22
+	DefaultTimeout = 10 * time.Second
23
+
24
+	// DefaultHTTPTimeout defines a default timeout for making HTTP request.
25
+	DefaultHTTPTimeout = 10 * time.Second
26
+
27
+	// DefaultIdleTimeout defines a timeout for idle HTTP connections
28
+	DefaultIdleTimeout = time.Minute
29
+
30
+	// DefaultTCPKeepAlivePeriod defines a time period between 2 consecuitive
31
+	// probes.
32
+	DefaultTCPKeepAlivePeriod = 10 * time.Second
33
+
34
+	// tcpLingerTimeout defines a number of seconds to wait for sending
35
+	// unacknowledged data.
36
+	tcpLingerTimeout = 1
37
+)
38
+
39
+var ErrCannotDial = errors.New("cannot dial to any address")
40
+
41
+type Network interface {
42
+	mtglib.Network
43
+
44
+	NativeDialer() *net.Dialer
45
+}

+ 70
- 0
network/v2/multi_network.go Visa fil

@@ -0,0 +1,70 @@
1
+package network
2
+
3
+import (
4
+	"context"
5
+	"errors"
6
+	"math/rand"
7
+	"net"
8
+	"net/http"
9
+
10
+	"github.com/9seconds/mtg/v2/essentials"
11
+)
12
+
13
+type multiNetwork struct {
14
+	networks []Network
15
+}
16
+
17
+func (m multiNetwork) Dial(network, address string) (essentials.Conn, error) {
18
+	return m.DialContext(context.Background(), network, address)
19
+}
20
+
21
+func (m multiNetwork) DialContext(ctx context.Context, network, address string) (essentials.Conn, error) {
22
+	networks := m.networks
23
+
24
+	if len(networks) > 1 {
25
+		networks = make([]Network, len(m.networks))
26
+		copy(networks, m.networks)
27
+
28
+		rand.Shuffle(len(m.networks), func(i, j int) {
29
+			networks[i], networks[j] = networks[j], networks[i]
30
+		})
31
+	}
32
+
33
+	errs := make([]error, 1, len(networks)+1)
34
+	errs[0] = ErrCannotDial
35
+
36
+	for _, ntw := range networks {
37
+		conn, err := ntw.DialContext(ctx, network, address)
38
+		if err == nil {
39
+			return conn, nil
40
+		}
41
+
42
+		errs = append(errs, err)
43
+	}
44
+
45
+	return nil, errors.Join(errs...)
46
+}
47
+
48
+func (m multiNetwork) NativeDialer() *net.Dialer {
49
+	return m.networks[0].NativeDialer()
50
+}
51
+
52
+func (m multiNetwork) MakeHTTPClient(
53
+	dialFunc func(context.Context, string, string) (essentials.Conn, error),
54
+) *http.Client {
55
+	if dialFunc == nil {
56
+		dialFunc = m.DialContext
57
+	}
58
+
59
+	return m.networks[0].MakeHTTPClient(dialFunc)
60
+}
61
+
62
+func Join(networks ...Network) (Network, error) {
63
+	if len(networks) == 0 {
64
+		return nil, errors.New("cannot join no networks")
65
+	}
66
+
67
+	return multiNetwork{
68
+		networks: networks,
69
+	}, nil
70
+}

+ 88
- 0
network/v2/network.go Visa fil

@@ -0,0 +1,88 @@
1
+package network
2
+
3
+import (
4
+	"context"
5
+	"fmt"
6
+	"net"
7
+	"net/http"
8
+	"time"
9
+
10
+	"github.com/9seconds/mtg/v2/essentials"
11
+)
12
+
13
+type network struct {
14
+	net.Dialer
15
+
16
+	httpTimeout time.Duration
17
+	idleTimeout time.Duration
18
+	userAgent   string
19
+}
20
+
21
+func (n *network) Dial(network, address string) (essentials.Conn, error) {
22
+	return n.DialContext(context.Background(), network, address)
23
+}
24
+
25
+func (n *network) DialContext(ctx context.Context, network, address string) (essentials.Conn, error) {
26
+	switch network {
27
+	case "tcp", "tcp4", "tcp6":
28
+	default:
29
+		return nil, fmt.Errorf("unsupported network %s", network)
30
+	}
31
+
32
+	conn, err := n.Dialer.DialContext(ctx, network, address)
33
+	if err != nil {
34
+		return nil, err
35
+	}
36
+
37
+	tcpConn := conn.(*net.TCPConn)
38
+
39
+	return tcpConn, setCommonSocketOptions(tcpConn)
40
+}
41
+
42
+func (n *network) MakeHTTPClient(
43
+	dialFunc func(context.Context, string, string) (essentials.Conn, error),
44
+) *http.Client {
45
+	if dialFunc == nil {
46
+		dialFunc = n.DialContext
47
+	}
48
+
49
+	return &http.Client{
50
+		Timeout: n.httpTimeout,
51
+		Transport: networkHTTPTransport{
52
+			userAgent: n.userAgent,
53
+			next: &http.Transport{
54
+				IdleConnTimeout: n.idleTimeout,
55
+				DialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
56
+					return dialFunc(ctx, network, address)
57
+				},
58
+			},
59
+		},
60
+	}
61
+}
62
+
63
+func (n *network) NativeDialer() *net.Dialer {
64
+	return &n.Dialer
65
+}
66
+
67
+func New(
68
+	dnsResolver *net.Resolver,
69
+	userAgent string,
70
+	tcpTimeout,
71
+	httpTimeout,
72
+	idleTimeout time.Duration,
73
+) Network {
74
+	if dnsResolver == nil {
75
+		dnsResolver = net.DefaultResolver
76
+	}
77
+
78
+	return &network{
79
+		Dialer: net.Dialer{
80
+			Timeout:       tcpTimeout,
81
+			Resolver:      dnsResolver,
82
+			FallbackDelay: -1,
83
+		},
84
+		userAgent:   userAgent,
85
+		idleTimeout: idleTimeout,
86
+		httpTimeout: httpTimeout,
87
+	}
88
+}

+ 36
- 0
network/v2/proxy_network.go Visa fil

@@ -0,0 +1,36 @@
1
+package network
2
+
3
+import (
4
+	"context"
5
+	"fmt"
6
+	"net/url"
7
+
8
+	"github.com/9seconds/mtg/v2/essentials"
9
+	"golang.org/x/net/proxy"
10
+)
11
+
12
+type proxyNetwork struct {
13
+	Network
14
+	client proxy.ContextDialer
15
+}
16
+
17
+func (p proxyNetwork) DialContext(ctx context.Context, network, address string) (essentials.Conn, error) {
18
+	conn, err := p.client.DialContext(ctx, network, address)
19
+	if err != nil {
20
+		return nil, err
21
+	}
22
+
23
+	return essentials.WrapNetConn(conn), nil
24
+}
25
+
26
+func NewProxyNetwork(base Network, proxyURL *url.URL) (*proxyNetwork, error) {
27
+	socks, err := proxy.FromURL(proxyURL, base.NativeDialer())
28
+	if err != nil {
29
+		return nil, fmt.Errorf("cannot build proxy dialer: %w", err)
30
+	}
31
+
32
+	return &proxyNetwork{
33
+		Network: base,
34
+		client:  socks.(proxy.ContextDialer),
35
+	}, nil
36
+}

+ 27
- 0
network/v2/sockopts.go Visa fil

@@ -0,0 +1,27 @@
1
+package network
2
+
3
+import (
4
+	"fmt"
5
+	"net"
6
+)
7
+
8
+func setCommonSocketOptions(conn *net.TCPConn) error {
9
+	if err := conn.SetKeepAlivePeriod(DefaultTCPKeepAlivePeriod); err != nil {
10
+		return fmt.Errorf("cannot set time period of TCP keepalive probes: %w", err)
11
+	}
12
+
13
+	if err := conn.SetLinger(tcpLingerTimeout); err != nil {
14
+		return fmt.Errorf("cannot set TCP linger timeout: %w", err)
15
+	}
16
+
17
+	rawConn, err := conn.SyscallConn()
18
+	if err != nil {
19
+		return fmt.Errorf("cannot get underlying raw connection: %w", err)
20
+	}
21
+
22
+	if err := setSocketReuseAddrPort(rawConn); err != nil {
23
+		return fmt.Errorf("cannot setup SO_REUSEADDR/PORT: %w", err)
24
+	}
25
+
26
+	return nil
27
+}

+ 31
- 0
network/v2/sockopts_unix.go Visa fil

@@ -0,0 +1,31 @@
1
+//go:build !windows
2
+// +build !windows
3
+
4
+package network
5
+
6
+import (
7
+	"fmt"
8
+	"syscall"
9
+
10
+	"golang.org/x/sys/unix"
11
+)
12
+
13
+func setSocketReuseAddrPort(conn syscall.RawConn) error {
14
+	var err error
15
+
16
+	conn.Control(func(fd uintptr) { //nolint: errcheck
17
+		err = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_REUSEADDR, 1)
18
+		if err != nil {
19
+			err = fmt.Errorf("cannot set SO_REUSEADDR: %w", err)
20
+
21
+			return
22
+		}
23
+
24
+		err = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_REUSEPORT, 1)
25
+		if err != nil {
26
+			err = fmt.Errorf("cannot set SO_REUSEPORT: %w", err)
27
+		}
28
+	})
29
+
30
+	return err
31
+}

+ 10
- 0
network/v2/sockopts_windows.go Visa fil

@@ -0,0 +1,10 @@
1
+//go:build windows
2
+// +build windows
3
+
4
+package network
5
+
6
+import "syscall"
7
+
8
+func setSocketReuseAddrPort(conn syscall.RawConn) error {
9
+	return nil
10
+}

+ 127
- 0
network/v2/socks_proxy_test.go Visa fil

@@ -0,0 +1,127 @@
1
+package network_test
2
+
3
+import (
4
+	"net"
5
+	"net/url"
6
+	"sync"
7
+	"testing"
8
+
9
+	"github.com/9seconds/mtg/v2/network/v2"
10
+	"github.com/stretchr/testify/assert"
11
+	"github.com/stretchr/testify/require"
12
+	"github.com/stretchr/testify/suite"
13
+	"github.com/things-go/go-socks5"
14
+)
15
+
16
+type SocksProxyTestSuite struct {
17
+	EchoServerTestSuite
18
+
19
+	wg          sync.WaitGroup
20
+	baseNetwork network.Network
21
+
22
+	noAuthURL *url.URL
23
+	authURL   *url.URL
24
+
25
+	noAuthListener net.Listener
26
+	authListener   net.Listener
27
+
28
+	noAuthServer *socks5.Server
29
+	authServer   *socks5.Server
30
+}
31
+
32
+func (suite *SocksProxyTestSuite) SetupSuite() {
33
+	suite.EchoServerTestSuite.SetupSuite()
34
+
35
+	listener, err := net.Listen("tcp4", "127.0.0.1:0")
36
+	require.NoError(suite.T(), err)
37
+	suite.noAuthListener = listener
38
+
39
+	listener, err = net.Listen("tcp4", "127.0.0.1:0")
40
+	require.NoError(suite.T(), err)
41
+	suite.authListener = listener
42
+
43
+	suite.noAuthServer = socks5.NewServer()
44
+	suite.wg.Go(func() {
45
+		suite.noAuthServer.Serve(suite.noAuthListener)
46
+	})
47
+
48
+	suite.authServer = socks5.NewServer(
49
+		socks5.WithAuthMethods([]socks5.Authenticator{
50
+			socks5.UserPassAuthenticator{
51
+				Credentials: socks5.StaticCredentials{
52
+					"user": "pass",
53
+				},
54
+			},
55
+		}))
56
+	suite.wg.Go(func() {
57
+		suite.authServer.Serve(suite.authListener)
58
+	})
59
+
60
+	parsed, err := url.Parse("socks5://" + suite.noAuthListener.Addr().String())
61
+	require.NoError(suite.T(), err)
62
+	suite.noAuthURL = parsed
63
+
64
+	parsed, err = url.Parse("socks5://user:pass@" + suite.authListener.Addr().String())
65
+	require.NoError(suite.T(), err)
66
+	suite.authURL = parsed
67
+
68
+	suite.baseNetwork = network.New(nil, "mtg", 0, 0, 0)
69
+}
70
+
71
+func (suite *SocksProxyTestSuite) TestIncorrectSchema() {
72
+	parsed, err := url.Parse("http://hello")
73
+	suite.NoError(err)
74
+
75
+	_, err = network.NewProxyNetwork(suite.baseNetwork, parsed)
76
+	suite.Error(err)
77
+}
78
+
79
+func (suite *SocksProxyTestSuite) TestRead() {
80
+	testData := map[string][]*url.URL{
81
+		"noAuth": {suite.noAuthURL},
82
+		"auth":   {suite.authURL},
83
+		"both":   {suite.noAuthURL, suite.authURL},
84
+	}
85
+
86
+	for name, proxies := range testData {
87
+		suite.T().Run(name, func(t *testing.T) {
88
+			proxyNetworks := []network.Network{}
89
+
90
+			for _, u := range proxies {
91
+				value, err := network.NewProxyNetwork(suite.baseNetwork, u)
92
+				assert.NoError(t, err)
93
+				proxyNetworks = append(proxyNetworks, value)
94
+			}
95
+
96
+			netw, err := network.Join(proxyNetworks...)
97
+			assert.NoError(t, err)
98
+
99
+			conn, err := netw.Dial("tcp4", suite.EchoServerAddr())
100
+			assert.NoError(t, err)
101
+
102
+			data := []byte{1, 2, 3}
103
+			n, err := conn.Write(data)
104
+			assert.NoError(t, err)
105
+			assert.Equal(t, len(data), n)
106
+
107
+			toRead := []byte{1, 2, 3, 4, 5}
108
+			n, err = conn.Read(toRead)
109
+			assert.NoError(t, err)
110
+			assert.Equal(t, len(data), n)
111
+			assert.Equal(t, data, toRead[:n])
112
+			assert.NotEqual(t, data, toRead)
113
+		})
114
+	}
115
+}
116
+
117
+func (suite *SocksProxyTestSuite) TearDownSuite() {
118
+	suite.noAuthListener.Close()
119
+	suite.authListener.Close()
120
+	suite.wg.Wait()
121
+	suite.EchoServerTestSuite.TearDownSuite()
122
+}
123
+
124
+func TestSocksProxy(t *testing.T) {
125
+	t.Parallel()
126
+	suite.Run(t, &SocksProxyTestSuite{})
127
+}

Laddar…
Avbryt
Spara