Просмотр исходного кода

Merge pull request #542 from 9seconds/multiple-ip-detectors

master
Sergei Arkhipov 5 дней назад
Родитель
Сommit
d095108334
Аккаунт пользователя с таким Email не найден
7 измененных файлов: 468 добавлений и 186 удалений
  1. 12
    10
      internal/cli/access.go
  2. 63
    39
      internal/cli/doctor.go
  3. 138
    0
      internal/cli/get_ip.go
  4. 16
    41
      internal/cli/run_proxy.go
  5. 43
    45
      internal/cli/sni_check.go
  6. 196
    0
      internal/cli/sni_check_test.go
  7. 0
    51
      internal/cli/utils.go

+ 12
- 10
internal/cli/access.go Просмотреть файл

1
 package cli
1
 package cli
2
 
2
 
3
 import (
3
 import (
4
+	"context"
4
 	"encoding/json"
5
 	"encoding/json"
5
 	"fmt"
6
 	"fmt"
6
 	"net"
7
 	"net"
54
 		return fmt.Errorf("cannot init network: %w", err)
55
 		return fmt.Errorf("cannot init network: %w", err)
55
 	}
56
 	}
56
 
57
 
57
-	wg := &sync.WaitGroup{}
58
+	ctx, cancel := context.WithTimeout(context.Background(), getIPTimeout)
59
+	defer cancel()
58
 
60
 
61
+	wg := &sync.WaitGroup{}
59
 	wg.Go(func() {
62
 	wg.Go(func() {
60
 		ip := a.PublicIPv4
63
 		ip := a.PublicIPv4
64
+
61
 		if ip == nil {
65
 		if ip == nil {
62
 			ip = conf.PublicIPv4.Get(nil)
66
 			ip = conf.PublicIPv4.Get(nil)
63
 		}
67
 		}
68
+
64
 		if ip == nil {
69
 		if ip == nil {
65
-			ip = getIP(ntw, "tcp4")
70
+			ip, _ = getIP(ctx, ntw, "tcp4")
66
 		}
71
 		}
67
 
72
 
68
 		if ip != nil {
73
 		if ip != nil {
69
-			ip = ip.To4()
74
+			resp.IPv4 = a.makeURLs(conf, ip.To4())
70
 		}
75
 		}
71
-
72
-		resp.IPv4 = a.makeURLs(conf, ip)
73
 	})
76
 	})
74
 	wg.Go(func() {
77
 	wg.Go(func() {
75
 		ip := a.PublicIPv6
78
 		ip := a.PublicIPv6
79
+
76
 		if ip == nil {
80
 		if ip == nil {
77
 			ip = conf.PublicIPv6.Get(nil)
81
 			ip = conf.PublicIPv6.Get(nil)
78
 		}
82
 		}
83
+
79
 		if ip == nil {
84
 		if ip == nil {
80
-			ip = getIP(ntw, "tcp6")
85
+			ip, _ = getIP(ctx, ntw, "tcp6")
81
 		}
86
 		}
82
 
87
 
83
 		if ip != nil {
88
 		if ip != nil {
84
-			ip = ip.To16()
89
+			resp.IPv6 = a.makeURLs(conf, ip.To16())
85
 		}
90
 		}
86
-
87
-		resp.IPv6 = a.makeURLs(conf, ip)
88
 	})
91
 	})
89
-
90
 	wg.Wait()
92
 	wg.Wait()
91
 
93
 
92
 	encoder := json.NewEncoder(os.Stdout)
94
 	encoder := json.NewEncoder(os.Stdout)

+ 63
- 39
internal/cli/doctor.go Просмотреть файл

24
 )
24
 )
25
 
25
 
26
 var (
26
 var (
27
+	funcs = template.FuncMap{
28
+		"join": strings.Join,
29
+	}
30
+
27
 	tplError = template.Must(
31
 	tplError = template.Must(
28
-		template.New("").Parse("  ‼️ {{ .description }}: {{ .error }}\n"),
32
+		template.New("").
33
+			Funcs(funcs).
34
+			Parse("  ‼️ {{ .description }}: {{ .error }}\n"),
29
 	)
35
 	)
30
 
36
 
31
 	tplWDeprecatedConfig = template.Must(
37
 	tplWDeprecatedConfig = template.Must(
32
 		template.New("").
38
 		template.New("").
39
+			Funcs(funcs).
33
 			Parse(`  ⚠️ Option {{ .old | printf "%q" }}{{ if .old_section }} from section [{{ .old_section }}]{{ end }} is deprecated and will be removed in v{{ .when }}. Please use {{ .new | printf "%q" }}{{ if .new_section }} in [{{ .new_section }}] section{{ end }} instead.` + "\n"),
40
 			Parse(`  ⚠️ Option {{ .old | printf "%q" }}{{ if .old_section }} from section [{{ .old_section }}]{{ end }} is deprecated and will be removed in v{{ .when }}. Please use {{ .new | printf "%q" }}{{ if .new_section }} in [{{ .new_section }}] section{{ end }} instead.` + "\n"),
34
 	)
41
 	)
35
 
42
 
36
 	tplOTimeSkewness = template.Must(
43
 	tplOTimeSkewness = template.Must(
37
 		template.New("").
44
 		template.New("").
45
+			Funcs(funcs).
38
 			Parse("  ✅ Time drift is {{ .drift }}, but tolerate-time-skewness is {{ .value }}\n"),
46
 			Parse("  ✅ Time drift is {{ .drift }}, but tolerate-time-skewness is {{ .value }}\n"),
39
 	)
47
 	)
40
 	tplWTimeSkewness = template.Must(
48
 	tplWTimeSkewness = template.Must(
41
 		template.New("").
49
 		template.New("").
50
+			Funcs(funcs).
42
 			Parse("  ⚠️ Time drift is {{ .drift }}, but tolerate-time-skewness is {{ .value }}. Please check ntp.\n"),
51
 			Parse("  ⚠️ Time drift is {{ .drift }}, but tolerate-time-skewness is {{ .value }}. Please check ntp.\n"),
43
 	)
52
 	)
44
 	tplETimeSkewness = template.Must(
53
 	tplETimeSkewness = template.Must(
45
 		template.New("").
54
 		template.New("").
55
+			Funcs(funcs).
46
 			Parse("  ❌ Time drift is {{ .drift }}, but tolerate-time-skewness is {{ .value }}. You will get many rejected connections!\n"),
56
 			Parse("  ❌ Time drift is {{ .drift }}, but tolerate-time-skewness is {{ .value }}. You will get many rejected connections!\n"),
47
 	)
57
 	)
48
 
58
 
49
 	tplODCConnect = template.Must(
59
 	tplODCConnect = template.Must(
50
-		template.New("").Parse("  ✅ DC {{ .dc }} (rpc {{ .rtt }})\n"),
60
+		template.New("").
61
+			Funcs(funcs).
62
+			Parse("  ✅ DC {{ .dc }} (rpc {{ .rtt }})\n"),
51
 	)
63
 	)
52
 	tplEDCConnect = template.Must(
64
 	tplEDCConnect = template.Must(
53
-		template.New("").Parse("  ❌ DC {{ .dc }}: {{ .error }}\n"),
65
+		template.New("").
66
+			Funcs(funcs).
67
+			Parse("  ❌ DC {{ .dc }}: {{ .error }}\n"),
54
 	)
68
 	)
55
 
69
 
56
 	tplODNSSNIMatch = template.Must(
70
 	tplODNSSNIMatch = template.Must(
57
-		template.New("").Parse("  ✅ IP address {{ .ip }} matches secret hostname {{ .hostname }}\n"),
71
+		template.New("").
72
+			Funcs(funcs).
73
+			Parse("  ✅ IP address {{ .ip }} matches secret hostname {{ .hostname }}\n"),
58
 	)
74
 	)
59
 	tplEDNSSNIMatch = template.Must(
75
 	tplEDNSSNIMatch = template.Must(
60
-		template.New("").Parse("  ❌ Hostname {{ .hostname }} {{ if .resolved }}resolves to {{ .resolved }}, but the proxy's public IP is {{ if .ip4 }}{{ .ip4 }}{{ else }}<not detected>{{ end }} (IPv4) / {{ if .ip6 }}{{ .ip6 }}{{ else }}<not detected>{{ end }} (IPv6) — none of the resolved addresses match{{ else }}cannot be resolved to any host{{ end }}\n"),
76
+		template.New("").
77
+			Funcs(funcs).
78
+			Parse(`  ❌ Hostname {{ .hostname }} resolves to {{ join ", " .resolved }} but public IP is {{ .ip }}` + "\n"),
61
 	)
79
 	)
62
 
80
 
63
 	tplOFrontingDomain = template.Must(
81
 	tplOFrontingDomain = template.Must(
64
-		template.New("").Parse("  ✅ {{ .address }} is reachable\n"),
82
+		template.New("").
83
+			Funcs(funcs).
84
+			Parse("  ✅ {{ .address }} is reachable\n"),
65
 	)
85
 	)
66
 	tplEFrontingDomain = template.Must(
86
 	tplEFrontingDomain = template.Must(
67
-		template.New("").Parse("  ❌ {{ .address }}: {{ .error }}\n"),
87
+		template.New("").
88
+			Funcs(funcs).
89
+			Parse("  ❌ {{ .address }}: {{ .error }}\n"),
68
 	)
90
 	)
69
 )
91
 )
70
 
92
 
71
 type Doctor struct {
93
 type Doctor struct {
72
 	conf *config.Config
94
 	conf *config.Config
73
 
95
 
74
-	ConfigPath      string `kong:"arg,required,type='existingfile',help='Path to the configuration file.',name='config-path'"` //nolint: lll
96
+	ConfigPath      string `kong:"arg,required,type='existingfile',help='Path to the configuration file.',name='config-path'"`                                                                       //nolint: lll
75
 	SkipNativeCheck bool   `kong:"help='Skip the native network connectivity check (useful when proxy chaining is configured and direct egress is not expected to work).',name='skip-native-check'"` //nolint: lll
97
 	SkipNativeCheck bool   `kong:"help='Skip the native network connectivity check (useful when proxy chaining is configured and direct egress is not expected to work).',name='skip-native-check'"` //nolint: lll
76
 }
98
 }
77
 
99
 
371
 }
393
 }
372
 
394
 
373
 func (d *Doctor) checkSecretHost(resolver *net.Resolver, ntw mtglib.Network) bool {
395
 func (d *Doctor) checkSecretHost(resolver *net.Resolver, ntw mtglib.Network) bool {
374
-	res := runSNICheck(context.Background(), resolver, d.conf, ntw)
375
-
376
-	if res.ResolveErr != nil {
396
+	res, err := runSNICheck(context.Background(), d.conf, resolver, ntw)
397
+	if err != nil {
377
 		tplError.Execute(os.Stdout, map[string]any{ //nolint: errcheck
398
 		tplError.Execute(os.Stdout, map[string]any{ //nolint: errcheck
378
 			"description": fmt.Sprintf("cannot resolve DNS name of %s", d.conf.Secret.Host),
399
 			"description": fmt.Sprintf("cannot resolve DNS name of %s", d.conf.Secret.Host),
379
-			"error":       res.ResolveErr,
400
+			"error":       err,
380
 		})
401
 		})
381
 		return false
402
 		return false
382
 	}
403
 	}
383
 
404
 
384
-	if !res.PublicIPKnown() {
405
+	if res.OurIP4 == "" && res.OurIP6 == "" {
385
 		tplError.Execute(os.Stdout, map[string]any{ //nolint: errcheck
406
 		tplError.Execute(os.Stdout, map[string]any{ //nolint: errcheck
386
 			"description": "cannot detect public IP address",
407
 			"description": "cannot detect public IP address",
387
 			"error":       errors.New("cannot detect automatically and public-ipv4/public-ipv6 are not set in config"),
408
 			"error":       errors.New("cannot detect automatically and public-ipv4/public-ipv6 are not set in config"),
389
 		return false
410
 		return false
390
 	}
411
 	}
391
 
412
 
392
-	if res.IPv4Match || res.IPv6Match {
393
-		var matched net.IP
413
+	ok := true
394
 
414
 
395
-		for _, ip := range res.Resolved {
396
-			if (res.OurIPv4 != nil && ip.String() == res.OurIPv4.String()) ||
397
-				(res.OurIPv6 != nil && ip.String() == res.OurIPv6.String()) {
398
-				matched = ip
399
-				break
400
-			}
415
+	if len(res.ResolvedIP4) > 0 {
416
+		if slices.Contains(res.ResolvedIP4, res.OurIP4) {
417
+			tplODNSSNIMatch.Execute(os.Stdout, map[string]any{ //nolint: errcheck
418
+				"ip":       res.OurIP4,
419
+				"hostname": d.conf.Secret.Host,
420
+			})
421
+		} else {
422
+			tplEDNSSNIMatch.Execute(os.Stdout, map[string]any{ //nolint: errcheck
423
+				"ip":       res.OurIP4,
424
+				"resolved": res.ResolvedIP4,
425
+				"hostname": d.conf.Secret.Host,
426
+			})
427
+			ok = false
401
 		}
428
 		}
402
-
403
-		tplODNSSNIMatch.Execute(os.Stdout, map[string]any{ //nolint: errcheck
404
-			"ip":       matched,
405
-			"hostname": d.conf.Secret.Host,
406
-		})
407
-		return true
408
 	}
429
 	}
409
-
410
-	strAddresses := make([]string, 0, len(res.Resolved))
411
-	for _, ip := range res.Resolved {
412
-		strAddresses = append(strAddresses, `"`+ip.String()+`"`)
430
+	if len(res.ResolvedIP6) > 0 {
431
+		if slices.Contains(res.ResolvedIP6, res.OurIP6) {
432
+			tplODNSSNIMatch.Execute(os.Stdout, map[string]any{ //nolint: errcheck
433
+				"ip":       res.OurIP6,
434
+				"hostname": d.conf.Secret.Host,
435
+			})
436
+		} else {
437
+			tplEDNSSNIMatch.Execute(os.Stdout, map[string]any{ //nolint: errcheck
438
+				"ip":       res.OurIP6,
439
+				"resolved": res.ResolvedIP6,
440
+				"hostname": d.conf.Secret.Host,
441
+			})
442
+			ok = false
443
+		}
413
 	}
444
 	}
414
 
445
 
415
-	tplEDNSSNIMatch.Execute(os.Stdout, map[string]any{ //nolint: errcheck
416
-		"hostname": d.conf.Secret.Host,
417
-		"resolved": strings.Join(strAddresses, ", "),
418
-		"ip4":      res.OurIPv4,
419
-		"ip6":      res.OurIPv6,
420
-	})
421
-
422
-	return false
446
+	return ok
423
 }
447
 }

+ 138
- 0
internal/cli/get_ip.go Просмотреть файл

1
+package cli
2
+
3
+import (
4
+	"bytes"
5
+	"context"
6
+	"errors"
7
+	"fmt"
8
+	"io"
9
+	"net"
10
+	"net/http"
11
+	"sync"
12
+	"time"
13
+
14
+	"github.com/9seconds/mtg/v2/essentials"
15
+	"github.com/9seconds/mtg/v2/mtglib"
16
+)
17
+
18
+const (
19
+	getIPTimeout = 5 * time.Second
20
+)
21
+
22
+var getIPServicesPlain = []string{
23
+	"https://ifconfig.co",
24
+	"https://ifconfig.me",
25
+	"https://api.ipify.org",
26
+	"https://ipecho.net/plain",
27
+}
28
+
29
+func getIP(ctx context.Context, ntw mtglib.Network, protocol string) (net.IP, error) {
30
+	ctx, cancel := context.WithTimeout(ctx, getIPTimeout)
31
+	defer cancel()
32
+
33
+	ctx, cancelCause := context.WithCancelCause(ctx)
34
+	defer cancelCause(nil)
35
+
36
+	var ip net.IP
37
+
38
+	rvChan := make(chan net.IP)
39
+	errChan := make(chan error)
40
+	errs := []error{}
41
+	wg := &sync.WaitGroup{}
42
+	dialer := ntw.NativeDialer()
43
+	client := ntw.MakeHTTPClient(func(_ context.Context, network, address string) (essentials.Conn, error) {
44
+		conn, err := dialer.DialContext(ctx, protocol, address)
45
+		if err != nil {
46
+			return nil, err
47
+		}
48
+		return essentials.WrapNetConn(conn), err
49
+	})
50
+
51
+	for _, url := range getIPServicesPlain {
52
+		wg.Go(func() {
53
+			lErrChan := errChan
54
+			rChan := rvChan
55
+
56
+			ip, err := getIPAddressPlain(ctx, client, url)
57
+			if err == nil {
58
+				lErrChan = nil
59
+			} else {
60
+				rChan = nil
61
+			}
62
+
63
+			select {
64
+			case <-ctx.Done():
65
+			case lErrChan <- fmt.Errorf("%s: %w", url, err):
66
+			case rChan <- ip:
67
+			}
68
+		})
69
+	}
70
+
71
+	wg.Go(func() {
72
+		defer cancelCause(nil)
73
+
74
+		for {
75
+			select {
76
+			case <-ctx.Done():
77
+				return
78
+			case foundIP := <-rvChan:
79
+				ip = foundIP
80
+				return
81
+			case err := <-errChan:
82
+				errs = append(errs, err)
83
+				if len(errs) == len(getIPServicesPlain) {
84
+					cancelCause(fmt.Errorf(
85
+						"cannot resolve %s address: %w",
86
+						protocol,
87
+						errors.Join(errs...),
88
+					))
89
+				}
90
+			}
91
+		}
92
+	})
93
+
94
+	wg.Wait()
95
+
96
+	if ip != nil {
97
+		return ip, nil
98
+	}
99
+
100
+	return nil, context.Cause(ctx)
101
+}
102
+
103
+func getIPAddressPlain(ctx context.Context, client *http.Client, address string) (net.IP, error) {
104
+	req, err := http.NewRequestWithContext(ctx, http.MethodGet, address, nil)
105
+	if err != nil {
106
+		panic(err)
107
+	}
108
+
109
+	req.Header.Add("Accept", "text/plain")
110
+
111
+	resp, err := client.Do(req)
112
+	if err != nil {
113
+		return nil, err
114
+	}
115
+
116
+	defer func() {
117
+		io.Copy(io.Discard, resp.Body) //nolint: errcheck
118
+		resp.Body.Close()              //nolint: errcheck
119
+	}()
120
+
121
+	if resp.StatusCode != http.StatusOK {
122
+		return nil, fmt.Errorf("unexpected status code %d", resp.StatusCode)
123
+	}
124
+
125
+	data, err := io.ReadAll(resp.Body)
126
+	if err != nil {
127
+		return nil, err
128
+	}
129
+
130
+	data = bytes.TrimSpace(data)
131
+	ip := net.ParseIP(string(data))
132
+
133
+	if ip == nil {
134
+		return nil, errors.New("cannot parse as IP address")
135
+	}
136
+
137
+	return ip, nil
138
+}

+ 16
- 41
internal/cli/run_proxy.go Просмотреть файл

5
 	"fmt"
5
 	"fmt"
6
 	"net"
6
 	"net"
7
 	"os"
7
 	"os"
8
+	"slices"
8
 	"strings"
9
 	"strings"
9
 
10
 
10
 	"github.com/9seconds/mtg/v2/antireplay"
11
 	"github.com/9seconds/mtg/v2/antireplay"
215
 		return
216
 		return
216
 	}
217
 	}
217
 
218
 
218
-	res := runSNICheck(context.Background(), net.DefaultResolver, conf, ntw)
219
+	log = log.BindStr("hostname", host)
219
 
220
 
220
-	if res.ResolveErr != nil {
221
-		log.BindStr("hostname", host).
222
-			WarningError("SNI-DNS check: cannot resolve secret hostname", res.ResolveErr)
221
+	res, err := runSNICheck(context.Background(), conf, net.DefaultResolver, ntw)
222
+	if err != nil {
223
+		log.WarningError("SNI-DNS check: cannot resolve secret hostname", err)
223
 		return
224
 		return
224
 	}
225
 	}
225
 
226
 
226
-	if !res.PublicIPKnown() {
227
+	if res.OurIP4 == "" && res.OurIP6 == "" {
227
 		log.Warning("SNI-DNS check: cannot detect public IP address; set public-ipv4/public-ipv6 in config or run 'mtg doctor'")
228
 		log.Warning("SNI-DNS check: cannot detect public IP address; set public-ipv4/public-ipv6 in config or run 'mtg doctor'")
228
 		return
229
 		return
229
 	}
230
 	}
230
 
231
 
231
-	v4Match := res.OurIPv4 == nil || res.IPv4Match
232
-	v6Match := res.OurIPv6 == nil || res.IPv6Match
233
-
234
-	if v4Match && v6Match {
235
-		return
236
-	}
237
-
238
-	resolved := make([]string, 0, len(res.Resolved))
239
-	for _, ip := range res.Resolved {
240
-		resolved = append(resolved, ip.String())
232
+	if len(res.ResolvedIP4) > 0 && !slices.Contains(res.ResolvedIP4, res.OurIP4) {
233
+		log.
234
+			BindStr("public_ip", res.OurIP4).
235
+			BindStr("resolved", strings.Join(res.ResolvedIP4, ",")).
236
+			Warning("SNI-DNS check: address mismatch")
241
 	}
237
 	}
242
 
238
 
243
-	our := ""
244
-	if res.OurIPv4 != nil {
245
-		our = res.OurIPv4.String()
239
+	if len(res.ResolvedIP6) > 0 && !slices.Contains(res.ResolvedIP6, res.OurIP6) {
240
+		log.
241
+			BindStr("public_ip", res.OurIP6).
242
+			BindStr("resolved", strings.Join(res.ResolvedIP6, ",")).
243
+			Warning("SNI-DNS check: address mismatch")
246
 	}
244
 	}
247
-
248
-	if res.OurIPv6 != nil {
249
-		if our != "" {
250
-			our += "/"
251
-		}
252
-
253
-		our += res.OurIPv6.String()
254
-	}
255
-
256
-	entry := log.BindStr("hostname", host).
257
-		BindStr("resolved", strings.Join(resolved, ", ")).
258
-		BindStr("public_ip", our)
259
-
260
-	if res.OurIPv4 != nil {
261
-		entry = entry.BindStr("ipv4_match", fmt.Sprintf("%t", v4Match))
262
-	}
263
-
264
-	if res.OurIPv6 != nil {
265
-		entry = entry.BindStr("ipv6_match", fmt.Sprintf("%t", v6Match))
266
-	}
267
-
268
-	entry.Warning("SNI-DNS mismatch: secret hostname does not resolve to this server's public IP. " +
269
-		"DPI may detect and block the proxy. See 'mtg doctor' for details")
270
 }
245
 }
271
 
246
 
272
 func warnDeprecatedDomainFronting(conf *config.Config, log mtglib.Logger) {
247
 func warnDeprecatedDomainFronting(conf *config.Config, log mtglib.Logger) {

+ 43
- 45
internal/cli/sni_check.go Просмотреть файл

2
 
2
 
3
 import (
3
 import (
4
 	"context"
4
 	"context"
5
+	"fmt"
5
 	"net"
6
 	"net"
7
+	"sync"
6
 
8
 
7
 	"github.com/9seconds/mtg/v2/internal/config"
9
 	"github.com/9seconds/mtg/v2/internal/config"
8
 	"github.com/9seconds/mtg/v2/mtglib"
10
 	"github.com/9seconds/mtg/v2/mtglib"
9
 )
11
 )
10
 
12
 
11
-// sniCheckResult holds the data gathered while comparing the secret
12
-// hostname's DNS records against this server's public IP addresses.
13
-//
14
-// IPv4Match / IPv6Match report whether a resolved record actually equals the
15
-// corresponding public IP. They are false when that family's public IP could
16
-// not be determined — there is nothing to compare against. Callers decide
17
-// what counts as a clean result from these fields: `mtg doctor` and the
18
-// startup warning apply different rules.
19
 type sniCheckResult struct {
13
 type sniCheckResult struct {
20
-	Resolved   []net.IP
21
-	OurIPv4    net.IP
22
-	OurIPv6    net.IP
23
-	IPv4Match  bool
24
-	IPv6Match  bool
25
-	ResolveErr error
14
+	ResolvedIP4 []string
15
+	ResolvedIP6 []string
16
+	OurIP4      string
17
+	OurIP6      string
26
 }
18
 }
27
 
19
 
28
-// PublicIPKnown reports whether at least one public IP family was detected.
29
-func (r sniCheckResult) PublicIPKnown() bool {
30
-	return r.OurIPv4 != nil || r.OurIPv6 != nil
31
-}
32
-
33
-// runSNICheck resolves conf.Secret.Host and compares the records with this
34
-// server's public IPv4 and IPv6. Public IPs come from config first and fall
35
-// back to on-the-fly detection via ntw. It gathers data only — it does not
36
-// decide success; see sniCheckResult.
37
 func runSNICheck(
20
 func runSNICheck(
38
 	ctx context.Context,
21
 	ctx context.Context,
39
-	resolver *net.Resolver,
40
 	conf *config.Config,
22
 	conf *config.Config,
23
+	resolver *net.Resolver,
41
 	ntw mtglib.Network,
24
 	ntw mtglib.Network,
42
-) sniCheckResult {
25
+) (sniCheckResult, error) {
43
 	res := sniCheckResult{}
26
 	res := sniCheckResult{}
44
 
27
 
45
 	addrs, err := resolver.LookupIPAddr(ctx, conf.Secret.Host)
28
 	addrs, err := resolver.LookupIPAddr(ctx, conf.Secret.Host)
46
 	if err != nil {
29
 	if err != nil {
47
-		res.ResolveErr = err
48
-
49
-		return res
30
+		return res, fmt.Errorf("cannot resolve addresses of %s: %w", conf.Secret.Host, err)
50
 	}
31
 	}
51
 
32
 
52
-	res.Resolved = make([]net.IP, 0, len(addrs))
53
-	for _, a := range addrs {
54
-		res.Resolved = append(res.Resolved, a.IP)
33
+	if len(addrs) == 0 {
34
+		return res, fmt.Errorf("no known addresses for %s", conf.Secret.Host)
55
 	}
35
 	}
56
 
36
 
57
-	res.OurIPv4 = conf.PublicIPv4.Get(nil)
58
-	if res.OurIPv4 == nil {
59
-		res.OurIPv4 = getIP(ntw, "tcp4")
37
+	for _, addr := range addrs {
38
+		if ip := addr.IP.To4(); ip == nil {
39
+			res.ResolvedIP6 = append(res.ResolvedIP6, addr.IP.To16().String())
40
+		} else {
41
+			res.ResolvedIP4 = append(res.ResolvedIP4, ip.String())
42
+		}
60
 	}
43
 	}
61
 
44
 
62
-	res.OurIPv6 = conf.PublicIPv6.Get(nil)
63
-	if res.OurIPv6 == nil {
64
-		res.OurIPv6 = getIP(ntw, "tcp6")
45
+	wg := &sync.WaitGroup{}
46
+
47
+	if len(res.ResolvedIP4) > 0 {
48
+		wg.Go(func() {
49
+			ip := conf.PublicIPv4.Get(nil)
50
+			if ip == nil {
51
+				ip, _ = getIP(ctx, ntw, "tcp4")
52
+			}
53
+
54
+			if ip != nil {
55
+				res.OurIP4 = ip.To4().String()
56
+			}
57
+		})
65
 	}
58
 	}
66
 
59
 
67
-	for _, ip := range res.Resolved {
68
-		if res.OurIPv4 != nil && ip.String() == res.OurIPv4.String() {
69
-			res.IPv4Match = true
70
-		}
60
+	if len(res.ResolvedIP6) > 0 {
61
+		wg.Go(func() {
62
+			ip := conf.PublicIPv6.Get(nil)
63
+			if ip == nil {
64
+				ip, _ = getIP(ctx, ntw, "tcp6")
65
+			}
71
 
66
 
72
-		if res.OurIPv6 != nil && ip.String() == res.OurIPv6.String() {
73
-			res.IPv6Match = true
74
-		}
67
+			if ip != nil {
68
+				res.OurIP6 = ip.To16().String()
69
+			}
70
+		})
75
 	}
71
 	}
76
 
72
 
77
-	return res
73
+	wg.Wait()
74
+
75
+	return res, nil
78
 }
76
 }

+ 196
- 0
internal/cli/sni_check_test.go Просмотреть файл

1
+package cli
2
+
3
+import (
4
+	"context"
5
+	"io"
6
+	"net"
7
+	"net/http"
8
+	"strings"
9
+	"testing"
10
+
11
+	"github.com/9seconds/mtg/v2/essentials"
12
+	"github.com/9seconds/mtg/v2/internal/config"
13
+	"github.com/9seconds/mtg/v2/mtglib"
14
+	"github.com/stretchr/testify/require"
15
+	"golang.org/x/net/dns/dnsmessage"
16
+)
17
+
18
+// startSNITestDNS spins up a loopback UDP resolver that answers every query
19
+// with the given A and AAAA records, so runSNICheck sees a dual-stack secret
20
+// host without touching the real network. It returns a *net.Resolver wired to
21
+// it.
22
+func startSNITestDNS(t *testing.T, a, aaaa net.IP) *net.Resolver {
23
+	t.Helper()
24
+
25
+	pc, err := net.ListenPacket("udp", "127.0.0.1:0")
26
+	require.NoError(t, err)
27
+	t.Cleanup(func() { pc.Close() }) //nolint: errcheck
28
+
29
+	go func() {
30
+		buf := make([]byte, 512)
31
+
32
+		for {
33
+			n, addr, err := pc.ReadFrom(buf)
34
+			if err != nil {
35
+				return
36
+			}
37
+
38
+			var parser dnsmessage.Parser
39
+
40
+			hdr, err := parser.Start(buf[:n])
41
+			if err != nil {
42
+				continue
43
+			}
44
+
45
+			question, err := parser.Question()
46
+			if err != nil {
47
+				continue
48
+			}
49
+
50
+			builder := dnsmessage.NewBuilder(nil, dnsmessage.Header{
51
+				ID:                 hdr.ID,
52
+				Response:           true,
53
+				RecursionAvailable: true,
54
+			})
55
+			builder.EnableCompression()
56
+			_ = builder.StartQuestions()
57
+			_ = builder.Question(question)
58
+			_ = builder.StartAnswers()
59
+
60
+			rh := dnsmessage.ResourceHeader{
61
+				Name:  question.Name,
62
+				Class: dnsmessage.ClassINET,
63
+				TTL:   60,
64
+			}
65
+
66
+			switch question.Type {
67
+			case dnsmessage.TypeA:
68
+				rh.Type = dnsmessage.TypeA
69
+				var v4 [4]byte
70
+				copy(v4[:], a.To4())
71
+				_ = builder.AResource(rh, dnsmessage.AResource{A: v4})
72
+			case dnsmessage.TypeAAAA:
73
+				rh.Type = dnsmessage.TypeAAAA
74
+				var v6 [16]byte
75
+				copy(v6[:], aaaa.To16())
76
+				_ = builder.AAAAResource(rh, dnsmessage.AAAAResource{AAAA: v6})
77
+			}
78
+
79
+			msg, err := builder.Finish()
80
+			if err != nil {
81
+				continue
82
+			}
83
+
84
+			pc.WriteTo(msg, addr) //nolint: errcheck
85
+		}
86
+	}()
87
+
88
+	dnsAddr := pc.LocalAddr().String()
89
+
90
+	return &net.Resolver{
91
+		PreferGo: true,
92
+		Dial: func(ctx context.Context, _, _ string) (net.Conn, error) {
93
+			var d net.Dialer
94
+
95
+			return d.DialContext(ctx, "udp", dnsAddr)
96
+		},
97
+	}
98
+}
99
+
100
+// ipv4OnlyEgressNetwork fakes mtglib.Network so that public-IP detection
101
+// succeeds over tcp4 and fails over tcp6 — the classic IPv4-only-egress
102
+// server. getIP's per-protocol dial is routed at a loopback listener: a tcp4
103
+// dial to 127.0.0.1 connects, a tcp6 dial to the same address fails ("no
104
+// suitable address"), so we exercise the real getIP code path without the
105
+// internet.
106
+type ipv4OnlyEgressNetwork struct {
107
+	listenerAddr string
108
+	detectedV4   string
109
+}
110
+
111
+func (n *ipv4OnlyEgressNetwork) Dial(_, _ string) (essentials.Conn, error) {
112
+	panic("unused")
113
+}
114
+
115
+func (n *ipv4OnlyEgressNetwork) DialContext(_ context.Context, _, _ string) (essentials.Conn, error) {
116
+	panic("unused")
117
+}
118
+
119
+func (n *ipv4OnlyEgressNetwork) NativeDialer() *net.Dialer {
120
+	return &net.Dialer{}
121
+}
122
+
123
+func (n *ipv4OnlyEgressNetwork) MakeHTTPClient(
124
+	dialFunc func(ctx context.Context, network, address string) (essentials.Conn, error),
125
+) *http.Client {
126
+	return &http.Client{
127
+		Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) {
128
+			conn, err := dialFunc(req.Context(), "tcp", n.listenerAddr)
129
+			if err != nil {
130
+				return nil, err
131
+			}
132
+
133
+			conn.Close() //nolint: errcheck
134
+
135
+			return &http.Response{
136
+				StatusCode: http.StatusOK,
137
+				Body:       io.NopCloser(strings.NewReader(n.detectedV4)),
138
+				Header:     make(http.Header),
139
+			}, nil
140
+		}),
141
+	}
142
+}
143
+
144
+type roundTripFunc func(*http.Request) (*http.Response, error)
145
+
146
+func (f roundTripFunc) RoundTrip(r *http.Request) (*http.Response, error) {
147
+	return f(r)
148
+}
149
+
150
+// TestRunSNICheckIPv4OnlyEgressGraceful reproduces the #529/#542 regression:
151
+// a dual-stack secret host on a server whose IPv6 egress is down. The tcp6
152
+// public-IP detection fails, but the tcp4 detection succeeds and matches the
153
+// host's A record, so the SNI check must NOT report a hard error — one
154
+// family being undetectable is graceful degradation, not failure.
155
+func TestRunSNICheckIPv4OnlyEgressGraceful(t *testing.T) {
156
+	const ourV4 = "192.0.2.4" // RFC 5737 TEST-NET-1
157
+
158
+	resolver := startSNITestDNS(t, net.ParseIP(ourV4), net.ParseIP("2001:db8::1")) // RFC 3849 doc range
159
+
160
+	// Loopback target for getIP's dial. Keep it the IPv4 literal 127.0.0.1: a
161
+	// "tcp6" dial to it fails deterministically ("no suitable address") on any
162
+	// host regardless of IPv6 connectivity, which is what makes tcp6 detection
163
+	// fail here. Do not replace with a real ::1 setup — that reintroduces flake.
164
+	listener, err := net.Listen("tcp", "127.0.0.1:0")
165
+	require.NoError(t, err)
166
+	t.Cleanup(func() { listener.Close() }) //nolint: errcheck
167
+
168
+	go func() {
169
+		for {
170
+			conn, err := listener.Accept()
171
+			if err != nil {
172
+				return
173
+			}
174
+
175
+			conn.Close() //nolint: errcheck
176
+		}
177
+	}()
178
+
179
+	ntw := &ipv4OnlyEgressNetwork{
180
+		listenerAddr: listener.Addr().String(),
181
+		detectedV4:   ourV4,
182
+	}
183
+
184
+	conf := &config.Config{}
185
+	conf.Secret.Host = "secret-host.test"
186
+
187
+	res, err := runSNICheck(context.Background(), conf, resolver, ntw)
188
+
189
+	// The load-bearing assertion: a single family's detection failure must not
190
+	// poison the whole result. Before the fix this returns a non-nil error.
191
+	require.NoError(t, err)
192
+	require.Equal(t, ourV4, res.OurIP4, "IPv4 public IP should be detected and match the A record")
193
+	require.Empty(t, res.OurIP6, "IPv6 is undetectable on IPv4-only egress; must degrade, not error")
194
+}
195
+
196
+var _ mtglib.Network = (*ipv4OnlyEgressNetwork)(nil)

+ 0
- 51
internal/cli/utils.go Просмотреть файл

1
-package cli
2
-
3
-import (
4
-	"context"
5
-	"io"
6
-	"net"
7
-	"net/http"
8
-	"strings"
9
-
10
-	"github.com/9seconds/mtg/v2/essentials"
11
-	"github.com/9seconds/mtg/v2/mtglib"
12
-)
13
-
14
-func getIP(ntw mtglib.Network, protocol string) net.IP {
15
-	dialer := ntw.NativeDialer()
16
-	client := ntw.MakeHTTPClient(func(ctx context.Context, network, address string) (essentials.Conn, error) {
17
-		conn, err := dialer.DialContext(ctx, protocol, address)
18
-		if err != nil {
19
-			return nil, err
20
-		}
21
-		return essentials.WrapNetConn(conn), err
22
-	})
23
-
24
-	req, err := http.NewRequest(http.MethodGet, "https://ifconfig.co", nil) //nolint: noctx
25
-	if err != nil {
26
-		panic(err)
27
-	}
28
-
29
-	req.Header.Add("Accept", "text/plain")
30
-
31
-	resp, err := client.Do(req)
32
-	if err != nil {
33
-		return nil
34
-	}
35
-
36
-	if resp.StatusCode != http.StatusOK {
37
-		return nil
38
-	}
39
-
40
-	defer func() {
41
-		io.Copy(io.Discard, resp.Body) //nolint: errcheck
42
-		resp.Body.Close()              //nolint: errcheck
43
-	}()
44
-
45
-	data, err := io.ReadAll(resp.Body)
46
-	if err != nil {
47
-		return nil
48
-	}
49
-
50
-	return net.ParseIP(strings.TrimSpace(string(data)))
51
-}

Загрузка…
Отмена
Сохранить