Просмотр исходного кода

Resolve URLs by using multiple services

This PR has an intention of resolving URLs by using multiple endpoints
that identify an IP address of the service. This is handy if one service
is blocked for some reason.

The detection mechanism follows this logic:

1. It tries to access all services in parallel
2. If service respond with some error (like, no route to host for IPv6),
   then we accurately collect those errors and return a merged one
3. In case of the first IP resolved, we immediately return it.

Also, this PR refactors how access and SNI check are performed.
pull/557/head
9seconds 2 недель назад
Родитель
Сommit
2145159f01
6 измененных файлов: 285 добавлений и 186 удалений
  1. 12
    10
      internal/cli/access.go
  2. 63
    39
      internal/cli/doctor.go
  3. 138
    0
      internal/cli/get_ip.go
  4. 16
    41
      internal/cli/run_proxy.go
  5. 56
    45
      internal/cli/sni_check.go
  6. 0
    51
      internal/cli/utils.go

+ 12
- 10
internal/cli/access.go Просмотреть файл

1
 package cli
1
 package cli
2
 
2
 
3
 import (
3
 import (
4
+	"context"
4
 	"encoding/json"
5
 	"encoding/json"
5
 	"fmt"
6
 	"fmt"
6
 	"net"
7
 	"net"
54
 		return fmt.Errorf("cannot init network: %w", err)
55
 		return fmt.Errorf("cannot init network: %w", err)
55
 	}
56
 	}
56
 
57
 
57
-	wg := &sync.WaitGroup{}
58
+	ctx, cancel := context.WithTimeout(context.Background(), getIPTimeout)
59
+	defer cancel()
58
 
60
 
61
+	wg := &sync.WaitGroup{}
59
 	wg.Go(func() {
62
 	wg.Go(func() {
60
 		ip := a.PublicIPv4
63
 		ip := a.PublicIPv4
64
+
61
 		if ip == nil {
65
 		if ip == nil {
62
 			ip = conf.PublicIPv4.Get(nil)
66
 			ip = conf.PublicIPv4.Get(nil)
63
 		}
67
 		}
68
+
64
 		if ip == nil {
69
 		if ip == nil {
65
-			ip = getIP(ntw, "tcp4")
70
+			ip, _ = getIP(ctx, ntw, "tcp4")
66
 		}
71
 		}
67
 
72
 
68
 		if ip != nil {
73
 		if ip != nil {
69
-			ip = ip.To4()
74
+			resp.IPv4 = a.makeURLs(conf, ip.To4())
70
 		}
75
 		}
71
-
72
-		resp.IPv4 = a.makeURLs(conf, ip)
73
 	})
76
 	})
74
 	wg.Go(func() {
77
 	wg.Go(func() {
75
 		ip := a.PublicIPv6
78
 		ip := a.PublicIPv6
79
+
76
 		if ip == nil {
80
 		if ip == nil {
77
 			ip = conf.PublicIPv6.Get(nil)
81
 			ip = conf.PublicIPv6.Get(nil)
78
 		}
82
 		}
83
+
79
 		if ip == nil {
84
 		if ip == nil {
80
-			ip = getIP(ntw, "tcp6")
85
+			ip, _ = getIP(ctx, ntw, "tcp6")
81
 		}
86
 		}
82
 
87
 
83
 		if ip != nil {
88
 		if ip != nil {
84
-			ip = ip.To16()
89
+			resp.IPv6 = a.makeURLs(conf, ip.To16())
85
 		}
90
 		}
86
-
87
-		resp.IPv6 = a.makeURLs(conf, ip)
88
 	})
91
 	})
89
-
90
 	wg.Wait()
92
 	wg.Wait()
91
 
93
 
92
 	encoder := json.NewEncoder(os.Stdout)
94
 	encoder := json.NewEncoder(os.Stdout)

+ 63
- 39
internal/cli/doctor.go Просмотреть файл

24
 )
24
 )
25
 
25
 
26
 var (
26
 var (
27
+	funcs = template.FuncMap{
28
+		"join": strings.Join,
29
+	}
30
+
27
 	tplError = template.Must(
31
 	tplError = template.Must(
28
-		template.New("").Parse("  ‼️ {{ .description }}: {{ .error }}\n"),
32
+		template.New("").
33
+			Funcs(funcs).
34
+			Parse("  ‼️ {{ .description }}: {{ .error }}\n"),
29
 	)
35
 	)
30
 
36
 
31
 	tplWDeprecatedConfig = template.Must(
37
 	tplWDeprecatedConfig = template.Must(
32
 		template.New("").
38
 		template.New("").
39
+			Funcs(funcs).
33
 			Parse(`  ⚠️ Option {{ .old | printf "%q" }}{{ if .old_section }} from section [{{ .old_section }}]{{ end }} is deprecated and will be removed in v{{ .when }}. Please use {{ .new | printf "%q" }}{{ if .new_section }} in [{{ .new_section }}] section{{ end }} instead.` + "\n"),
40
 			Parse(`  ⚠️ Option {{ .old | printf "%q" }}{{ if .old_section }} from section [{{ .old_section }}]{{ end }} is deprecated and will be removed in v{{ .when }}. Please use {{ .new | printf "%q" }}{{ if .new_section }} in [{{ .new_section }}] section{{ end }} instead.` + "\n"),
34
 	)
41
 	)
35
 
42
 
36
 	tplOTimeSkewness = template.Must(
43
 	tplOTimeSkewness = template.Must(
37
 		template.New("").
44
 		template.New("").
45
+			Funcs(funcs).
38
 			Parse("  ✅ Time drift is {{ .drift }}, but tolerate-time-skewness is {{ .value }}\n"),
46
 			Parse("  ✅ Time drift is {{ .drift }}, but tolerate-time-skewness is {{ .value }}\n"),
39
 	)
47
 	)
40
 	tplWTimeSkewness = template.Must(
48
 	tplWTimeSkewness = template.Must(
41
 		template.New("").
49
 		template.New("").
50
+			Funcs(funcs).
42
 			Parse("  ⚠️ Time drift is {{ .drift }}, but tolerate-time-skewness is {{ .value }}. Please check ntp.\n"),
51
 			Parse("  ⚠️ Time drift is {{ .drift }}, but tolerate-time-skewness is {{ .value }}. Please check ntp.\n"),
43
 	)
52
 	)
44
 	tplETimeSkewness = template.Must(
53
 	tplETimeSkewness = template.Must(
45
 		template.New("").
54
 		template.New("").
55
+			Funcs(funcs).
46
 			Parse("  ❌ Time drift is {{ .drift }}, but tolerate-time-skewness is {{ .value }}. You will get many rejected connections!\n"),
56
 			Parse("  ❌ Time drift is {{ .drift }}, but tolerate-time-skewness is {{ .value }}. You will get many rejected connections!\n"),
47
 	)
57
 	)
48
 
58
 
49
 	tplODCConnect = template.Must(
59
 	tplODCConnect = template.Must(
50
-		template.New("").Parse("  ✅ DC {{ .dc }} (rpc {{ .rtt }})\n"),
60
+		template.New("").
61
+			Funcs(funcs).
62
+			Parse("  ✅ DC {{ .dc }} (rpc {{ .rtt }})\n"),
51
 	)
63
 	)
52
 	tplEDCConnect = template.Must(
64
 	tplEDCConnect = template.Must(
53
-		template.New("").Parse("  ❌ DC {{ .dc }}: {{ .error }}\n"),
65
+		template.New("").
66
+			Funcs(funcs).
67
+			Parse("  ❌ DC {{ .dc }}: {{ .error }}\n"),
54
 	)
68
 	)
55
 
69
 
56
 	tplODNSSNIMatch = template.Must(
70
 	tplODNSSNIMatch = template.Must(
57
-		template.New("").Parse("  ✅ IP address {{ .ip }} matches secret hostname {{ .hostname }}\n"),
71
+		template.New("").
72
+			Funcs(funcs).
73
+			Parse("  ✅ IP address {{ .ip }} matches secret hostname {{ .hostname }}\n"),
58
 	)
74
 	)
59
 	tplEDNSSNIMatch = template.Must(
75
 	tplEDNSSNIMatch = template.Must(
60
-		template.New("").Parse("  ❌ Hostname {{ .hostname }} {{ if .resolved }}resolves to {{ .resolved }}, but the proxy's public IP is {{ if .ip4 }}{{ .ip4 }}{{ else }}<not detected>{{ end }} (IPv4) / {{ if .ip6 }}{{ .ip6 }}{{ else }}<not detected>{{ end }} (IPv6) — none of the resolved addresses match{{ else }}cannot be resolved to any host{{ end }}\n"),
76
+		template.New("").
77
+			Funcs(funcs).
78
+			Parse(`  ❌ Hostname {{ .hostname }} resolves to {{ join ", " .resolved }} but public IP is {{ .ip }}` + "\n"),
61
 	)
79
 	)
62
 
80
 
63
 	tplOFrontingDomain = template.Must(
81
 	tplOFrontingDomain = template.Must(
64
-		template.New("").Parse("  ✅ {{ .address }} is reachable\n"),
82
+		template.New("").
83
+			Funcs(funcs).
84
+			Parse("  ✅ {{ .address }} is reachable\n"),
65
 	)
85
 	)
66
 	tplEFrontingDomain = template.Must(
86
 	tplEFrontingDomain = template.Must(
67
-		template.New("").Parse("  ❌ {{ .address }}: {{ .error }}\n"),
87
+		template.New("").
88
+			Funcs(funcs).
89
+			Parse("  ❌ {{ .address }}: {{ .error }}\n"),
68
 	)
90
 	)
69
 )
91
 )
70
 
92
 
71
 type Doctor struct {
93
 type Doctor struct {
72
 	conf *config.Config
94
 	conf *config.Config
73
 
95
 
74
-	ConfigPath      string `kong:"arg,required,type='existingfile',help='Path to the configuration file.',name='config-path'"` //nolint: lll
96
+	ConfigPath      string `kong:"arg,required,type='existingfile',help='Path to the configuration file.',name='config-path'"`                                                                       //nolint: lll
75
 	SkipNativeCheck bool   `kong:"help='Skip the native network connectivity check (useful when proxy chaining is configured and direct egress is not expected to work).',name='skip-native-check'"` //nolint: lll
97
 	SkipNativeCheck bool   `kong:"help='Skip the native network connectivity check (useful when proxy chaining is configured and direct egress is not expected to work).',name='skip-native-check'"` //nolint: lll
76
 }
98
 }
77
 
99
 
371
 }
393
 }
372
 
394
 
373
 func (d *Doctor) checkSecretHost(resolver *net.Resolver, ntw mtglib.Network) bool {
395
 func (d *Doctor) checkSecretHost(resolver *net.Resolver, ntw mtglib.Network) bool {
374
-	res := runSNICheck(context.Background(), resolver, d.conf, ntw)
375
-
376
-	if res.ResolveErr != nil {
396
+	res, err := runSNICheck(context.Background(), d.conf, resolver, ntw)
397
+	if err != nil {
377
 		tplError.Execute(os.Stdout, map[string]any{ //nolint: errcheck
398
 		tplError.Execute(os.Stdout, map[string]any{ //nolint: errcheck
378
 			"description": fmt.Sprintf("cannot resolve DNS name of %s", d.conf.Secret.Host),
399
 			"description": fmt.Sprintf("cannot resolve DNS name of %s", d.conf.Secret.Host),
379
-			"error":       res.ResolveErr,
400
+			"error":       err,
380
 		})
401
 		})
381
 		return false
402
 		return false
382
 	}
403
 	}
383
 
404
 
384
-	if !res.PublicIPKnown() {
405
+	if res.OurIP4 == "" && res.OurIP6 == "" {
385
 		tplError.Execute(os.Stdout, map[string]any{ //nolint: errcheck
406
 		tplError.Execute(os.Stdout, map[string]any{ //nolint: errcheck
386
 			"description": "cannot detect public IP address",
407
 			"description": "cannot detect public IP address",
387
 			"error":       errors.New("cannot detect automatically and public-ipv4/public-ipv6 are not set in config"),
408
 			"error":       errors.New("cannot detect automatically and public-ipv4/public-ipv6 are not set in config"),
389
 		return false
410
 		return false
390
 	}
411
 	}
391
 
412
 
392
-	if res.IPv4Match || res.IPv6Match {
393
-		var matched net.IP
413
+	ok := true
394
 
414
 
395
-		for _, ip := range res.Resolved {
396
-			if (res.OurIPv4 != nil && ip.String() == res.OurIPv4.String()) ||
397
-				(res.OurIPv6 != nil && ip.String() == res.OurIPv6.String()) {
398
-				matched = ip
399
-				break
400
-			}
415
+	if len(res.ResolvedIP4) > 0 {
416
+		if slices.Contains(res.ResolvedIP4, res.OurIP4) {
417
+			tplODNSSNIMatch.Execute(os.Stdout, map[string]any{ //nolint: errcheck
418
+				"ip":       res.OurIP4,
419
+				"hostname": d.conf.Secret.Host,
420
+			})
421
+		} else {
422
+			tplEDNSSNIMatch.Execute(os.Stdout, map[string]any{ //nolint: errcheck
423
+				"ip":       res.OurIP4,
424
+				"resolved": res.ResolvedIP4,
425
+				"hostname": d.conf.Secret.Host,
426
+			})
427
+			ok = false
401
 		}
428
 		}
402
-
403
-		tplODNSSNIMatch.Execute(os.Stdout, map[string]any{ //nolint: errcheck
404
-			"ip":       matched,
405
-			"hostname": d.conf.Secret.Host,
406
-		})
407
-		return true
408
 	}
429
 	}
409
-
410
-	strAddresses := make([]string, 0, len(res.Resolved))
411
-	for _, ip := range res.Resolved {
412
-		strAddresses = append(strAddresses, `"`+ip.String()+`"`)
430
+	if len(res.ResolvedIP6) > 0 {
431
+		if slices.Contains(res.ResolvedIP6, res.OurIP6) {
432
+			tplODNSSNIMatch.Execute(os.Stdout, map[string]any{ //nolint: errcheck
433
+				"ip":       res.OurIP6,
434
+				"hostname": d.conf.Secret.Host,
435
+			})
436
+		} else {
437
+			tplEDNSSNIMatch.Execute(os.Stdout, map[string]any{ //nolint: errcheck
438
+				"ip":       res.OurIP6,
439
+				"resolved": res.ResolvedIP6,
440
+				"hostname": d.conf.Secret.Host,
441
+			})
442
+			ok = false
443
+		}
413
 	}
444
 	}
414
 
445
 
415
-	tplEDNSSNIMatch.Execute(os.Stdout, map[string]any{ //nolint: errcheck
416
-		"hostname": d.conf.Secret.Host,
417
-		"resolved": strings.Join(strAddresses, ", "),
418
-		"ip4":      res.OurIPv4,
419
-		"ip6":      res.OurIPv6,
420
-	})
421
-
422
-	return false
446
+	return ok
423
 }
447
 }

+ 138
- 0
internal/cli/get_ip.go Просмотреть файл

1
+package cli
2
+
3
+import (
4
+	"bytes"
5
+	"context"
6
+	"errors"
7
+	"fmt"
8
+	"io"
9
+	"net"
10
+	"net/http"
11
+	"sync"
12
+	"time"
13
+
14
+	"github.com/9seconds/mtg/v2/essentials"
15
+	"github.com/9seconds/mtg/v2/mtglib"
16
+)
17
+
18
+const (
19
+	getIPTimeout = 5 * time.Second
20
+)
21
+
22
+var getIPServicesPlain = []string{
23
+	"https://ifconfig.co",
24
+	"https://ifconfig.me",
25
+	"https://api.ipify.org",
26
+	"https://ipecho.net/plain",
27
+}
28
+
29
+func getIP(ctx context.Context, ntw mtglib.Network, protocol string) (net.IP, error) {
30
+	ctx, cancel := context.WithTimeout(ctx, getIPTimeout)
31
+	defer cancel()
32
+
33
+	ctx, cancelCause := context.WithCancelCause(ctx)
34
+	defer cancelCause(nil)
35
+
36
+	var ip net.IP
37
+
38
+	rvChan := make(chan net.IP)
39
+	errChan := make(chan error)
40
+	errs := []error{}
41
+	wg := &sync.WaitGroup{}
42
+	dialer := ntw.NativeDialer()
43
+	client := ntw.MakeHTTPClient(func(_ context.Context, network, address string) (essentials.Conn, error) {
44
+		conn, err := dialer.DialContext(ctx, protocol, address)
45
+		if err != nil {
46
+			return nil, err
47
+		}
48
+		return essentials.WrapNetConn(conn), err
49
+	})
50
+
51
+	for _, url := range getIPServicesPlain {
52
+		wg.Go(func() {
53
+			lErrChan := errChan
54
+			rChan := rvChan
55
+
56
+			ip, err := getIPAddressPlain(ctx, client, url)
57
+			if err == nil {
58
+				lErrChan = nil
59
+			} else {
60
+				rChan = nil
61
+			}
62
+
63
+			select {
64
+			case <-ctx.Done():
65
+			case lErrChan <- fmt.Errorf("%s: %w", url, err):
66
+			case rChan <- ip:
67
+			}
68
+		})
69
+	}
70
+
71
+	wg.Go(func() {
72
+		defer cancelCause(nil)
73
+
74
+		for {
75
+			select {
76
+			case <-ctx.Done():
77
+				return
78
+			case foundIP := <-rvChan:
79
+				ip = foundIP
80
+				return
81
+			case err := <-errChan:
82
+				errs = append(errs, err)
83
+				if len(errs) == len(getIPServicesPlain) {
84
+					cancelCause(fmt.Errorf(
85
+						"cannot resolve %s address: %w",
86
+						protocol,
87
+						errors.Join(errs...),
88
+					))
89
+				}
90
+			}
91
+		}
92
+	})
93
+
94
+	wg.Wait()
95
+
96
+	if ip != nil {
97
+		return ip, nil
98
+	}
99
+
100
+	return nil, context.Cause(ctx)
101
+}
102
+
103
+func getIPAddressPlain(ctx context.Context, client *http.Client, address string) (net.IP, error) {
104
+	req, err := http.NewRequestWithContext(ctx, http.MethodGet, address, nil)
105
+	if err != nil {
106
+		panic(err)
107
+	}
108
+
109
+	req.Header.Add("Accept", "text/plain")
110
+
111
+	resp, err := client.Do(req)
112
+	if err != nil {
113
+		return nil, err
114
+	}
115
+
116
+	defer func() {
117
+		io.Copy(io.Discard, resp.Body) //nolint: errcheck
118
+		resp.Body.Close()              //nolint: errcheck
119
+	}()
120
+
121
+	if resp.StatusCode != http.StatusOK {
122
+		return nil, fmt.Errorf("unexpected status code %d", resp.StatusCode)
123
+	}
124
+
125
+	data, err := io.ReadAll(resp.Body)
126
+	if err != nil {
127
+		return nil, err
128
+	}
129
+
130
+	data = bytes.TrimSpace(data)
131
+	ip := net.ParseIP(string(data))
132
+
133
+	if ip == nil {
134
+		return nil, errors.New("cannot parse as IP address")
135
+	}
136
+
137
+	return ip, nil
138
+}

+ 16
- 41
internal/cli/run_proxy.go Просмотреть файл

5
 	"fmt"
5
 	"fmt"
6
 	"net"
6
 	"net"
7
 	"os"
7
 	"os"
8
+	"slices"
8
 	"strings"
9
 	"strings"
9
 
10
 
10
 	"github.com/9seconds/mtg/v2/antireplay"
11
 	"github.com/9seconds/mtg/v2/antireplay"
215
 		return
216
 		return
216
 	}
217
 	}
217
 
218
 
218
-	res := runSNICheck(context.Background(), net.DefaultResolver, conf, ntw)
219
+	log = log.BindStr("hostname", host)
219
 
220
 
220
-	if res.ResolveErr != nil {
221
-		log.BindStr("hostname", host).
222
-			WarningError("SNI-DNS check: cannot resolve secret hostname", res.ResolveErr)
221
+	res, err := runSNICheck(context.Background(), conf, net.DefaultResolver, ntw)
222
+	if err != nil {
223
+		log.WarningError("SNI-DNS check: cannot resolve secret hostname", err)
223
 		return
224
 		return
224
 	}
225
 	}
225
 
226
 
226
-	if !res.PublicIPKnown() {
227
+	if res.OurIP4 == "" && res.OurIP6 == "" {
227
 		log.Warning("SNI-DNS check: cannot detect public IP address; set public-ipv4/public-ipv6 in config or run 'mtg doctor'")
228
 		log.Warning("SNI-DNS check: cannot detect public IP address; set public-ipv4/public-ipv6 in config or run 'mtg doctor'")
228
 		return
229
 		return
229
 	}
230
 	}
230
 
231
 
231
-	v4Match := res.OurIPv4 == nil || res.IPv4Match
232
-	v6Match := res.OurIPv6 == nil || res.IPv6Match
233
-
234
-	if v4Match && v6Match {
235
-		return
236
-	}
237
-
238
-	resolved := make([]string, 0, len(res.Resolved))
239
-	for _, ip := range res.Resolved {
240
-		resolved = append(resolved, ip.String())
232
+	if len(res.ResolvedIP4) > 0 && !slices.Contains(res.ResolvedIP4, res.OurIP4) {
233
+		log.
234
+			BindStr("public_ip", res.OurIP4).
235
+			BindStr("resolved", strings.Join(res.ResolvedIP4, ",")).
236
+			Warning("SNI-DNS check: address mismatch")
241
 	}
237
 	}
242
 
238
 
243
-	our := ""
244
-	if res.OurIPv4 != nil {
245
-		our = res.OurIPv4.String()
239
+	if len(res.ResolvedIP6) > 0 && !slices.Contains(res.ResolvedIP6, res.OurIP6) {
240
+		log.
241
+			BindStr("public_ip", res.OurIP6).
242
+			BindStr("resolved", strings.Join(res.ResolvedIP6, ",")).
243
+			Warning("SNI-DNS check: address mismatch")
246
 	}
244
 	}
247
-
248
-	if res.OurIPv6 != nil {
249
-		if our != "" {
250
-			our += "/"
251
-		}
252
-
253
-		our += res.OurIPv6.String()
254
-	}
255
-
256
-	entry := log.BindStr("hostname", host).
257
-		BindStr("resolved", strings.Join(resolved, ", ")).
258
-		BindStr("public_ip", our)
259
-
260
-	if res.OurIPv4 != nil {
261
-		entry = entry.BindStr("ipv4_match", fmt.Sprintf("%t", v4Match))
262
-	}
263
-
264
-	if res.OurIPv6 != nil {
265
-		entry = entry.BindStr("ipv6_match", fmt.Sprintf("%t", v6Match))
266
-	}
267
-
268
-	entry.Warning("SNI-DNS mismatch: secret hostname does not resolve to this server's public IP. " +
269
-		"DPI may detect and block the proxy. See 'mtg doctor' for details")
270
 }
245
 }
271
 
246
 
272
 func warnDeprecatedDomainFronting(conf *config.Config, log mtglib.Logger) {
247
 func warnDeprecatedDomainFronting(conf *config.Config, log mtglib.Logger) {

+ 56
- 45
internal/cli/sni_check.go Просмотреть файл

2
 
2
 
3
 import (
3
 import (
4
 	"context"
4
 	"context"
5
+	"fmt"
5
 	"net"
6
 	"net"
7
+	"sync"
6
 
8
 
7
 	"github.com/9seconds/mtg/v2/internal/config"
9
 	"github.com/9seconds/mtg/v2/internal/config"
8
 	"github.com/9seconds/mtg/v2/mtglib"
10
 	"github.com/9seconds/mtg/v2/mtglib"
9
 )
11
 )
10
 
12
 
11
-// sniCheckResult holds the data gathered while comparing the secret
12
-// hostname's DNS records against this server's public IP addresses.
13
-//
14
-// IPv4Match / IPv6Match report whether a resolved record actually equals the
15
-// corresponding public IP. They are false when that family's public IP could
16
-// not be determined — there is nothing to compare against. Callers decide
17
-// what counts as a clean result from these fields: `mtg doctor` and the
18
-// startup warning apply different rules.
19
 type sniCheckResult struct {
13
 type sniCheckResult struct {
20
-	Resolved   []net.IP
21
-	OurIPv4    net.IP
22
-	OurIPv6    net.IP
23
-	IPv4Match  bool
24
-	IPv6Match  bool
25
-	ResolveErr error
14
+	ResolvedIP4 []string
15
+	ResolvedIP6 []string
16
+	OurIP4      string
17
+	OurIP6      string
26
 }
18
 }
27
 
19
 
28
-// PublicIPKnown reports whether at least one public IP family was detected.
29
-func (r sniCheckResult) PublicIPKnown() bool {
30
-	return r.OurIPv4 != nil || r.OurIPv6 != nil
31
-}
32
-
33
-// runSNICheck resolves conf.Secret.Host and compares the records with this
34
-// server's public IPv4 and IPv6. Public IPs come from config first and fall
35
-// back to on-the-fly detection via ntw. It gathers data only — it does not
36
-// decide success; see sniCheckResult.
37
 func runSNICheck(
20
 func runSNICheck(
38
 	ctx context.Context,
21
 	ctx context.Context,
39
-	resolver *net.Resolver,
40
 	conf *config.Config,
22
 	conf *config.Config,
23
+	resolver *net.Resolver,
41
 	ntw mtglib.Network,
24
 	ntw mtglib.Network,
42
-) sniCheckResult {
25
+) (sniCheckResult, error) {
43
 	res := sniCheckResult{}
26
 	res := sniCheckResult{}
44
 
27
 
28
+	ctx, cancelCause := context.WithCancelCause(ctx)
29
+	defer cancelCause(nil)
30
+
45
 	addrs, err := resolver.LookupIPAddr(ctx, conf.Secret.Host)
31
 	addrs, err := resolver.LookupIPAddr(ctx, conf.Secret.Host)
46
 	if err != nil {
32
 	if err != nil {
47
-		res.ResolveErr = err
48
-
49
-		return res
33
+		return res, fmt.Errorf("cannot resolve addresses of %s: %w", conf.Secret.Host, err)
50
 	}
34
 	}
51
 
35
 
52
-	res.Resolved = make([]net.IP, 0, len(addrs))
53
-	for _, a := range addrs {
54
-		res.Resolved = append(res.Resolved, a.IP)
36
+	if len(addrs) == 0 {
37
+		return res, fmt.Errorf("no known addresses for %s", conf.Secret.Host)
55
 	}
38
 	}
56
 
39
 
57
-	res.OurIPv4 = conf.PublicIPv4.Get(nil)
58
-	if res.OurIPv4 == nil {
59
-		res.OurIPv4 = getIP(ntw, "tcp4")
40
+	for _, addr := range addrs {
41
+		if ip := addr.IP.To4(); ip == nil {
42
+			res.ResolvedIP6 = append(res.ResolvedIP6, addr.IP.To16().String())
43
+		} else {
44
+			res.ResolvedIP4 = append(res.ResolvedIP4, ip.String())
45
+		}
60
 	}
46
 	}
61
 
47
 
62
-	res.OurIPv6 = conf.PublicIPv6.Get(nil)
63
-	if res.OurIPv6 == nil {
64
-		res.OurIPv6 = getIP(ntw, "tcp6")
48
+	wg := &sync.WaitGroup{}
49
+
50
+	if len(res.ResolvedIP4) > 0 {
51
+		wg.Go(func() {
52
+			var err error
53
+
54
+			ip := conf.PublicIPv4.Get(nil)
55
+			if ip == nil {
56
+				ip, err = getIP(ctx, ntw, "tcp4")
57
+				if err != nil {
58
+					cancelCause(err)
59
+				}
60
+			}
61
+
62
+			if ip != nil {
63
+				res.OurIP4 = ip.To4().String()
64
+			}
65
+		})
65
 	}
66
 	}
66
 
67
 
67
-	for _, ip := range res.Resolved {
68
-		if res.OurIPv4 != nil && ip.String() == res.OurIPv4.String() {
69
-			res.IPv4Match = true
70
-		}
68
+	if len(res.ResolvedIP6) > 0 {
69
+		wg.Go(func() {
70
+			var err error
71
 
71
 
72
-		if res.OurIPv6 != nil && ip.String() == res.OurIPv6.String() {
73
-			res.IPv6Match = true
74
-		}
72
+			ip := conf.PublicIPv6.Get(nil)
73
+			if ip == nil {
74
+				ip, err = getIP(ctx, ntw, "tcp6")
75
+				if err != nil {
76
+					cancelCause(err)
77
+				}
78
+			}
79
+
80
+			if ip != nil {
81
+				res.OurIP6 = ip.To16().String()
82
+			}
83
+		})
75
 	}
84
 	}
76
 
85
 
77
-	return res
86
+	wg.Wait()
87
+
88
+	return res, context.Cause(ctx)
78
 }
89
 }

+ 0
- 51
internal/cli/utils.go Просмотреть файл

1
-package cli
2
-
3
-import (
4
-	"context"
5
-	"io"
6
-	"net"
7
-	"net/http"
8
-	"strings"
9
-
10
-	"github.com/9seconds/mtg/v2/essentials"
11
-	"github.com/9seconds/mtg/v2/mtglib"
12
-)
13
-
14
-func getIP(ntw mtglib.Network, protocol string) net.IP {
15
-	dialer := ntw.NativeDialer()
16
-	client := ntw.MakeHTTPClient(func(ctx context.Context, network, address string) (essentials.Conn, error) {
17
-		conn, err := dialer.DialContext(ctx, protocol, address)
18
-		if err != nil {
19
-			return nil, err
20
-		}
21
-		return essentials.WrapNetConn(conn), err
22
-	})
23
-
24
-	req, err := http.NewRequest(http.MethodGet, "https://ifconfig.co", nil) //nolint: noctx
25
-	if err != nil {
26
-		panic(err)
27
-	}
28
-
29
-	req.Header.Add("Accept", "text/plain")
30
-
31
-	resp, err := client.Do(req)
32
-	if err != nil {
33
-		return nil
34
-	}
35
-
36
-	if resp.StatusCode != http.StatusOK {
37
-		return nil
38
-	}
39
-
40
-	defer func() {
41
-		io.Copy(io.Discard, resp.Body) //nolint: errcheck
42
-		resp.Body.Close()              //nolint: errcheck
43
-	}()
44
-
45
-	data, err := io.ReadAll(resp.Body)
46
-	if err != nil {
47
-		return nil
48
-	}
49
-
50
-	return net.ParseIP(strings.TrimSpace(string(data)))
51
-}

Загрузка…
Отмена
Сохранить