Browse Source

Change network to accept DialFunc

tags/v2.0.0-rc1
9seconds 5 years ago
parent
commit
d6566e5dfb

+ 7
- 17
cli/access.go View File

38
 	Hex        bool   `help:"Print secret in hex encoding."`
38
 	Hex        bool   `help:"Print secret in hex encoding."`
39
 }
39
 }
40
 
40
 
41
-func (c *Access) Run(cli *CLI) error {
42
-	if err := c.ReadConfig(cli.Access.ConfigPath); err != nil {
41
+func (c *Access) Run(cli *CLI, version string) error {
42
+	if err := c.ReadConfig(cli.Access.ConfigPath, version); err != nil {
43
 		return fmt.Errorf("cannot init config: %w", err)
43
 		return fmt.Errorf("cannot init config: %w", err)
44
 	}
44
 	}
45
 
45
 
74
 }
74
 }
75
 
75
 
76
 func (c *Access) getIP(protocol string) net.IP {
76
 func (c *Access) getIP(protocol string) net.IP {
77
-	client := c.network.MakeHTTPClient(0)
78
-	client.Transport = &http.Transport{
79
-		DialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
80
-			return c.network.DialContext(ctx, protocol, address)
81
-		},
82
-	}
83
-
84
-	c.network.PrepareHTTPClient(client)
85
-
86
-	req, err := http.NewRequest(http.MethodGet, "https://ifconfig.co", nil)
87
-	if err != nil {
88
-		panic(err)
89
-	}
77
+	client := c.network.MakeHTTPClient(func(ctx context.Context, network, address string) (net.Conn, error) {
78
+		return c.network.DialContext(ctx, protocol, address)
79
+	})
90
 
80
 
91
-	resp, err := client.Do(req)
81
+	resp, err := client.Get("https://ifconfig.co") // nolint: noctx
92
 	if err != nil {
82
 	if err != nil {
93
 		return nil
83
 		return nil
94
 	}
84
 	}
98
 	}
88
 	}
99
 
89
 
100
 	defer func() {
90
 	defer func() {
101
-		io.Copy(ioutil.Discard, resp.Body)
91
+		io.Copy(ioutil.Discard, resp.Body) // nolint: errcheck
102
 		resp.Body.Close()
92
 		resp.Body.Close()
103
 	}()
93
 	}()
104
 
94
 

+ 7
- 6
cli/base.go View File

15
 	conf    *config.Config
15
 	conf    *config.Config
16
 }
16
 }
17
 
17
 
18
-func (b *base) ReadConfig(path string) error {
18
+func (b *base) ReadConfig(path, version string) error {
19
 	content, err := ioutil.ReadFile(path)
19
 	content, err := ioutil.ReadFile(path)
20
 	if err != nil {
20
 	if err != nil {
21
 		return fmt.Errorf("cannot read config file: %w", err)
21
 		return fmt.Errorf("cannot read config file: %w", err)
26
 		return fmt.Errorf("cannot parse config: %w", err)
26
 		return fmt.Errorf("cannot parse config: %w", err)
27
 	}
27
 	}
28
 
28
 
29
-	ntw, err := b.makeNetwork(conf)
29
+	ntw, err := b.makeNetwork(conf, version)
30
 	if err != nil {
30
 	if err != nil {
31
 		return fmt.Errorf("cannot build a network: %w", err)
31
 		return fmt.Errorf("cannot build a network: %w", err)
32
 	}
32
 	}
37
 	return nil
37
 	return nil
38
 }
38
 }
39
 
39
 
40
-func (b *base) makeNetwork(conf *config.Config) (network.Network, error) {
40
+func (b *base) makeNetwork(conf *config.Config, version string) (network.Network, error) {
41
 	tcpTimeout := conf.Network.Timeout.TCP.Value(network.DefaultTimeout)
41
 	tcpTimeout := conf.Network.Timeout.TCP.Value(network.DefaultTimeout)
42
 	idleTimeout := conf.Network.Timeout.Idle.Value(network.DefaultIdleTimeout)
42
 	idleTimeout := conf.Network.Timeout.Idle.Value(network.DefaultIdleTimeout)
43
 	dohIP := conf.Network.DOHIP.Value(net.ParseIP(network.DefaultDOHHostname)).String()
43
 	dohIP := conf.Network.DOHIP.Value(net.ParseIP(network.DefaultDOHHostname)).String()
44
 	bufferSize := conf.TCPBuffer.Value(network.DefaultBufferSize)
44
 	bufferSize := conf.TCPBuffer.Value(network.DefaultBufferSize)
45
+	userAgent := "mtg/" + version
45
 
46
 
46
 	baseDialer, err := network.NewDefaultDialer(tcpTimeout, int(bufferSize))
47
 	baseDialer, err := network.NewDefaultDialer(tcpTimeout, int(bufferSize))
47
 	if err != nil {
48
 	if err != nil {
58
 
59
 
59
 	switch len(proxyURLs) {
60
 	switch len(proxyURLs) {
60
 	case 0:
61
 	case 0:
61
-		return network.NewNetwork(baseDialer, dohIP, idleTimeout)
62
+		return network.NewNetwork(baseDialer, userAgent, dohIP, idleTimeout)
62
 	case 1:
63
 	case 1:
63
 		socksDialer, err := network.NewSocks5Dialer(baseDialer, proxyURLs[0])
64
 		socksDialer, err := network.NewSocks5Dialer(baseDialer, proxyURLs[0])
64
 		if err != nil {
65
 		if err != nil {
65
 			return nil, fmt.Errorf("cannot build socks5 dialer: %w", err)
66
 			return nil, fmt.Errorf("cannot build socks5 dialer: %w", err)
66
 		}
67
 		}
67
 
68
 
68
-		return network.NewNetwork(socksDialer, dohIP, idleTimeout)
69
+		return network.NewNetwork(socksDialer, userAgent, dohIP, idleTimeout)
69
 	}
70
 	}
70
 
71
 
71
 	socksDialer, err := network.NewLoadBalancedSocks5Dialer(baseDialer, proxyURLs)
72
 	socksDialer, err := network.NewLoadBalancedSocks5Dialer(baseDialer, proxyURLs)
73
 		return nil, fmt.Errorf("cannot build socks5 dialer: %w", err)
74
 		return nil, fmt.Errorf("cannot build socks5 dialer: %w", err)
74
 	}
75
 	}
75
 
76
 
76
-	return network.NewNetwork(socksDialer, dohIP, idleTimeout)
77
+	return network.NewNetwork(socksDialer, userAgent, dohIP, idleTimeout)
77
 }
78
 }

+ 1
- 3
cli/cli.go View File

1
 package cli
1
 package cli
2
 
2
 
3
-import (
4
-	"github.com/alecthomas/kong"
5
-)
3
+import "github.com/alecthomas/kong"
6
 
4
 
7
 type CLI struct {
5
 type CLI struct {
8
 	GenerateSecret GenerateSecret   `cmd help:"Generate new proxy secret"` // nolint: govet
6
 	GenerateSecret GenerateSecret   `cmd help:"Generate new proxy secret"` // nolint: govet

+ 1
- 1
cli/generate_secret.go View File

13
 	Hex      bool   `help:"Print secret in hex encoding."`
13
 	Hex      bool   `help:"Print secret in hex encoding."`
14
 }
14
 }
15
 
15
 
16
-func (c *GenerateSecret) Run(cli *CLI) error { // nolint: unparam
16
+func (c *GenerateSecret) Run(cli *CLI, _ string) error {
17
 	secret := mtglib.GenerateSecret(cli.GenerateSecret.HostName)
17
 	secret := mtglib.GenerateSecret(cli.GenerateSecret.HostName)
18
 
18
 
19
 	if cli.GenerateSecret.Hex {
19
 	if cli.GenerateSecret.Hex {

+ 1
- 1
config/type_bytes.go View File

30
 	return nil
30
 	return nil
31
 }
31
 }
32
 
32
 
33
-func (c TypeBytes) MarshalText() ([]byte, error) { // nolint: unparam
33
+func (c TypeBytes) MarshalText() ([]byte, error) {
34
 	return []byte(c.String()), nil
34
 	return []byte(c.String()), nil
35
 }
35
 }
36
 
36
 

+ 1
- 1
config/type_duration.go View File

29
 	return nil
29
 	return nil
30
 }
30
 }
31
 
31
 
32
-func (c TypeDuration) MarshalText() ([]byte, error) { // nolint: unparam
32
+func (c TypeDuration) MarshalText() ([]byte, error) {
33
 	return []byte(c.value.String()), nil
33
 	return []byte(c.value.String()), nil
34
 }
34
 }
35
 
35
 

+ 1
- 1
config/type_float.go View File

24
 	return nil
24
 	return nil
25
 }
25
 }
26
 
26
 
27
-func (c *TypeFloat) MarshalText() ([]byte, error) { // nolint: unparam
27
+func (c *TypeFloat) MarshalText() ([]byte, error) {
28
 	return []byte(c.String()), nil
28
 	return []byte(c.String()), nil
29
 }
29
 }
30
 
30
 

+ 3
- 3
config/type_hostport.go View File

32
 	return nil
32
 	return nil
33
 }
33
 }
34
 
34
 
35
-func (c TypeHostPort) MarshalText() ([]byte, error) { // nolint: unparam
35
+func (c TypeHostPort) MarshalText() ([]byte, error) {
36
 	return []byte(c.String()), nil
36
 	return []byte(c.String()), nil
37
 }
37
 }
38
 
38
 
41
 }
41
 }
42
 
42
 
43
 func (c TypeHostPort) HostValue(defaultValue net.IP) net.IP {
43
 func (c TypeHostPort) HostValue(defaultValue net.IP) net.IP {
44
-    return c.host.Value(defaultValue)
44
+	return c.host.Value(defaultValue)
45
 }
45
 }
46
 
46
 
47
 func (c TypeHostPort) PortValue(defaultValue uint) uint {
47
 func (c TypeHostPort) PortValue(defaultValue uint) uint {
48
-    return c.port.Value(defaultValue)
48
+	return c.port.Value(defaultValue)
49
 }
49
 }
50
 
50
 
51
 func (c TypeHostPort) Value(defaultHostValue net.IP, defaultPortValue uint) string {
51
 func (c TypeHostPort) Value(defaultHostValue net.IP, defaultPortValue uint) string {

+ 2
- 2
config/type_http_path.go View File

6
 	value string
6
 	value string
7
 }
7
 }
8
 
8
 
9
-func (c *TypeHTTPPath) UnmarshalText(data []byte) error { // nolint: unparam
9
+func (c *TypeHTTPPath) UnmarshalText(data []byte) error {
10
 	if len(data) > 0 {
10
 	if len(data) > 0 {
11
 		c.value = "/" + strings.Trim(string(data), "/")
11
 		c.value = "/" + strings.Trim(string(data), "/")
12
 	}
12
 	}
14
 	return nil
14
 	return nil
15
 }
15
 }
16
 
16
 
17
-func (c TypeHTTPPath) MarshalText() ([]byte, error) { // nolint: unparam
17
+func (c TypeHTTPPath) MarshalText() ([]byte, error) {
18
 	return []byte(c.String()), nil
18
 	return []byte(c.String()), nil
19
 }
19
 }
20
 
20
 

+ 1
- 1
config/type_ip.go View File

24
 	return nil
24
 	return nil
25
 }
25
 }
26
 
26
 
27
-func (c *TypeIP) MarshalText() ([]byte, error) { // nolint: unparam
27
+func (c *TypeIP) MarshalText() ([]byte, error) {
28
 	return []byte(c.String()), nil
28
 	return []byte(c.String()), nil
29
 }
29
 }
30
 
30
 

+ 1
- 1
config/type_metric_prefix.go View File

25
 	return nil
25
 	return nil
26
 }
26
 }
27
 
27
 
28
-func (c TypeMetricPrefix) MarshalText() ([]byte, error) { // nolint: unparam
28
+func (c TypeMetricPrefix) MarshalText() ([]byte, error) {
29
 	return []byte(c.String()), nil
29
 	return []byte(c.String()), nil
30
 }
30
 }
31
 
31
 

+ 2
- 2
config/type_port.go View File

28
 	return nil
28
 	return nil
29
 }
29
 }
30
 
30
 
31
-func (c *TypePort) MarshalJSON() ([]byte, error) { // nolint: unparam
32
-    return []byte(c.String()), nil
31
+func (c *TypePort) MarshalJSON() ([]byte, error) {
32
+	return []byte(c.String()), nil
33
 }
33
 }
34
 
34
 
35
 func (c TypePort) String() string {
35
 func (c TypePort) String() string {

+ 1
- 1
config/type_prefer_ip.go View File

26
 	return nil
26
 	return nil
27
 }
27
 }
28
 
28
 
29
-func (c TypePreferIP) MarshalText() ([]byte, error) { // nolint: unparam
29
+func (c TypePreferIP) MarshalText() ([]byte, error) {
30
 	return []byte(c.value), nil
30
 	return []byte(c.value), nil
31
 }
31
 }
32
 
32
 

+ 1
- 1
config/type_url.go View File

24
 	return nil
24
 	return nil
25
 }
25
 }
26
 
26
 
27
-func (c *TypeURL) MarshalText() ([]byte, error) { // nolint: unparam
27
+func (c *TypeURL) MarshalText() ([]byte, error) {
28
 	return []byte(c.String()), nil
28
 	return []byte(c.String()), nil
29
 }
29
 }
30
 
30
 

+ 1
- 1
main.go View File

19
 		"version":      version,
19
 		"version":      version,
20
 	})
20
 	})
21
 
21
 
22
-	ctx.FatalIfErrorf(ctx.Run(cli))
22
+	ctx.FatalIfErrorf(ctx.Run(cli, version))
23
 }
23
 }

+ 5
- 4
mtglib/network/init.go View File

11
 const (
11
 const (
12
 	DefaultTimeout     = 10 * time.Second
12
 	DefaultTimeout     = 10 * time.Second
13
 	DefaultIdleTimeout = time.Minute
13
 	DefaultIdleTimeout = time.Minute
14
-	DefaultHTTPTimeout = 5 * time.Second
15
 	DefaultBufferSize  = 4096
14
 	DefaultBufferSize  = 4096
16
 
15
 
17
 	ProxyDialerOpenThreshold        = 5
16
 	ProxyDialerOpenThreshold        = 5
20
 
19
 
21
 	DefaultDOHHostname = "9.9.9.9"
20
 	DefaultDOHHostname = "9.9.9.9"
22
 
21
 
23
-	DNSTimeout = 5 * time.Second
22
+	DNSTimeout  = 5 * time.Second
23
+	HTTPTimeout = 10 * time.Second
24
 )
24
 )
25
 
25
 
26
 var (
26
 var (
28
 	ErrCannotDialWithAllProxies = errors.New("cannot dial with all proxies")
28
 	ErrCannotDialWithAllProxies = errors.New("cannot dial with all proxies")
29
 )
29
 )
30
 
30
 
31
+type DialFunc func(ctx context.Context, protocol, address string) (net.Conn, error)
32
+
31
 type Dialer interface {
33
 type Dialer interface {
32
 	Dial(network, address string) (net.Conn, error)
34
 	Dial(network, address string) (net.Conn, error)
33
 	DialContext(ctx context.Context, network, address string) (net.Conn, error)
35
 	DialContext(ctx context.Context, network, address string) (net.Conn, error)
37
 	Dialer
39
 	Dialer
38
 
40
 
39
 	DNSResolve(network, hostname string) (ips []string, err error)
41
 	DNSResolve(network, hostname string) (ips []string, err error)
40
-	MakeHTTPClient(timeout time.Duration) *http.Client
42
+	MakeHTTPClient(DialFunc) *http.Client
41
 	IdleTimeout() time.Duration
43
 	IdleTimeout() time.Duration
42
-	PrepareHTTPClient(*http.Client)
43
 }
44
 }

+ 31
- 14
mtglib/network/network.go View File

12
 	doh "github.com/babolivier/go-doh-client"
12
 	doh "github.com/babolivier/go-doh-client"
13
 )
13
 )
14
 
14
 
15
+type networkHTTPTransport struct {
16
+	userAgent string
17
+	next      http.RoundTripper
18
+}
19
+
20
+func (n networkHTTPTransport) RoundTrip(req *http.Request) (*http.Response, error) {
21
+	req.Header.Set("User-Agent", n.userAgent)
22
+
23
+	return n.next.RoundTrip(req)
24
+}
25
+
15
 type network struct {
26
 type network struct {
16
-	idleTimeout time.Duration
17
 	dialer      Dialer
27
 	dialer      Dialer
18
 	dns         doh.Resolver
28
 	dns         doh.Resolver
29
+	idleTimeout time.Duration
30
+	userAgent   string
19
 }
31
 }
20
 
32
 
21
 func (n *network) Dial(protocol, address string) (net.Conn, error) {
33
 func (n *network) Dial(protocol, address string) (net.Conn, error) {
107
 	return n.idleTimeout
119
 	return n.idleTimeout
108
 }
120
 }
109
 
121
 
110
-func (n *network) MakeHTTPClient(timeout time.Duration) *http.Client {
111
-	if timeout <= 0 {
112
-		timeout = DefaultHTTPTimeout
113
-	}
114
-
115
-	return &http.Client{
116
-		Timeout: timeout,
117
-		Transport: &http.Transport{
118
-			DialContext: n.DialContext,
119
-		},
122
+func (n *network) MakeHTTPClient(dialFunc DialFunc) *http.Client {
123
+	if dialFunc == nil {
124
+		dialFunc = n.DialContext
120
 	}
125
 	}
121
-}
122
 
126
 
123
-func (n *network) PatchHTTPClient(_ *http.Client) {
127
+	return makeHTTPClient(n.userAgent, dialFunc)
124
 }
128
 }
125
 
129
 
126
-func NewNetwork(dialer Dialer, dohHostname string, idleTimeout time.Duration) (Network, error) {
130
+func NewNetwork(dialer Dialer, userAgent, dohHostname string, idleTimeout time.Duration) (Network, error) {
127
 	switch {
131
 	switch {
128
 	case idleTimeout < 0:
132
 	case idleTimeout < 0:
129
 		return nil, fmt.Errorf("timeout should be positive number %s", idleTimeout)
133
 		return nil, fmt.Errorf("timeout should be positive number %s", idleTimeout)
138
 	return &network{
142
 	return &network{
139
 		dialer:      dialer,
143
 		dialer:      dialer,
140
 		idleTimeout: idleTimeout,
144
 		idleTimeout: idleTimeout,
145
+		userAgent:   userAgent,
141
 		dns: doh.Resolver{
146
 		dns: doh.Resolver{
142
 			Host:  dohHostname,
147
 			Host:  dohHostname,
143
 			Class: doh.IN,
148
 			Class: doh.IN,
150
 		},
155
 		},
151
 	}, nil
156
 	}, nil
152
 }
157
 }
158
+
159
+func makeHTTPClient(userAgent string, dialFunc DialFunc) *http.Client {
160
+	return &http.Client{
161
+		Timeout: HTTPTimeout,
162
+		Transport: networkHTTPTransport{
163
+			userAgent: userAgent,
164
+			next: &http.Transport{
165
+				DialContext: dialFunc,
166
+			},
167
+		},
168
+	}
169
+}

+ 0
- 57
utils.go View File

1
-package main
2
-
3
-import (
4
-	"fmt"
5
-	"io"
6
-	"io/ioutil"
7
-	"net"
8
-	"net/http"
9
-	"net/url"
10
-
11
-	"github.com/9seconds/mtg/v2/config"
12
-	"github.com/9seconds/mtg/v2/mtglib/network"
13
-)
14
-
15
-func makeNetwork(conf *config.Config) (network.Network, error) {
16
-	tcpTimeout := conf.Network.Timeout.TCP.Value(network.DefaultTimeout)
17
-	idleTimeout := conf.Network.Timeout.Idle.Value(network.DefaultIdleTimeout)
18
-	dohIP := conf.Network.DOHIP.Value(net.ParseIP(network.DefaultDOHHostname)).String()
19
-	bufferSize := conf.TCPBuffer.Value(network.DefaultBufferSize)
20
-
21
-	baseDialer, err := network.NewDefaultDialer(tcpTimeout, int(bufferSize))
22
-	if err != nil {
23
-		return nil, fmt.Errorf("cannot build a default dialer: %w", err)
24
-	}
25
-
26
-	proxyURLs := make([]*url.URL, 0, len(conf.Network.Proxies))
27
-
28
-	for _, v := range conf.Network.Proxies {
29
-		if value := v.Value(nil); value != nil {
30
-			proxyURLs = append(proxyURLs, v.Value(nil))
31
-		}
32
-	}
33
-
34
-	switch len(proxyURLs) {
35
-	case 0:
36
-		return network.NewNetwork(baseDialer, dohIP, idleTimeout)
37
-	case 1:
38
-		socksDialer, err := network.NewSocks5Dialer(baseDialer, proxyURLs[0])
39
-		if err != nil {
40
-			return nil, fmt.Errorf("cannot build socks5 dialer: %w", err)
41
-		}
42
-
43
-		return network.NewNetwork(socksDialer, dohIP, idleTimeout)
44
-	}
45
-
46
-	socksDialer, err := network.NewLoadBalancedSocks5Dialer(baseDialer, proxyURLs)
47
-	if err != nil {
48
-		return nil, fmt.Errorf("cannot build socks5 dialer: %w", err)
49
-	}
50
-
51
-	return network.NewNetwork(socksDialer, dohIP, idleTimeout)
52
-}
53
-
54
-func exhaustResponse(response *http.Response) {
55
-	io.Copy(ioutil.Discard, response.Body) // nolint: errcheck
56
-	response.Body.Close()
57
-}

Loading…
Cancel
Save