瀏覽代碼

internal/cli: consolidate duplicated SNI-DNS check

`doctor`'s checkSecretHost and the proxy-startup warnSNIMismatch each
carried their own copy of the same logic: resolve the secret hostname,
determine the server's public IPv4/IPv6 (config first, getIP fallback),
and compare the two sets.

Extract that data-gathering into runSNICheck (internal/cli/sni_check.go),
returning an sniCheckResult. The success decision stays with each caller
because the rules genuinely differ — `doctor` reports OK when any family
matches, while the startup warning requires every detected family to
match — so only the gathering is shared, not the verdict.

No behavior change: both callers produce byte-identical output and the
same return values as before.
pull/528/head
Alexey Dolotov 6 天之前
父節點
當前提交
9593becc2a
共有 3 個檔案被更改,包括 120 行新增61 行删除
  1. 26
    26
      internal/cli/doctor.go
  2. 16
    35
      internal/cli/run_proxy.go
  3. 78
    0
      internal/cli/sni_check.go

+ 26
- 26
internal/cli/doctor.go 查看文件

@@ -361,26 +361,17 @@ func (d *Doctor) checkFrontingDomain(ntw mtglib.Network) bool {
361 361
 }
362 362
 
363 363
 func (d *Doctor) checkSecretHost(resolver *net.Resolver, ntw mtglib.Network) bool {
364
-	addresses, err := resolver.LookupIPAddr(context.Background(), d.conf.Secret.Host)
365
-	if err != nil {
364
+	res := runSNICheck(context.Background(), resolver, d.conf, ntw)
365
+
366
+	if res.ResolveErr != nil {
366 367
 		tplError.Execute(os.Stdout, map[string]any{ //nolint: errcheck
367 368
 			"description": fmt.Sprintf("cannot resolve DNS name of %s", d.conf.Secret.Host),
368
-			"error":       err,
369
+			"error":       res.ResolveErr,
369 370
 		})
370 371
 		return false
371 372
 	}
372 373
 
373
-	ourIP4 := d.conf.PublicIPv4.Get(nil)
374
-	if ourIP4 == nil {
375
-		ourIP4 = getIP(ntw, "tcp4")
376
-	}
377
-
378
-	ourIP6 := d.conf.PublicIPv6.Get(nil)
379
-	if ourIP6 == nil {
380
-		ourIP6 = getIP(ntw, "tcp6")
381
-	}
382
-
383
-	if ourIP4 == nil && ourIP6 == nil {
374
+	if !res.PublicIPKnown() {
384 375
 		tplError.Execute(os.Stdout, map[string]any{ //nolint: errcheck
385 376
 			"description": "cannot detect public IP address",
386 377
 			"error":       errors.New("cannot detect automatically and public-ipv4/public-ipv6 are not set in config"),
@@ -388,25 +379,34 @@ func (d *Doctor) checkSecretHost(resolver *net.Resolver, ntw mtglib.Network) boo
388 379
 		return false
389 380
 	}
390 381
 
391
-	strAddresses := []string{}
392
-	for _, value := range addresses {
393
-		if (ourIP4 != nil && value.IP.String() == ourIP4.String()) ||
394
-			(ourIP6 != nil && value.IP.String() == ourIP6.String()) {
395
-			tplODNSSNIMatch.Execute(os.Stdout, map[string]any{ //nolint: errcheck
396
-				"ip":       value.IP,
397
-				"hostname": d.conf.Secret.Host,
398
-			})
399
-			return true
382
+	if res.IPv4Match || res.IPv6Match {
383
+		var matched net.IP
384
+
385
+		for _, ip := range res.Resolved {
386
+			if (res.OurIPv4 != nil && ip.String() == res.OurIPv4.String()) ||
387
+				(res.OurIPv6 != nil && ip.String() == res.OurIPv6.String()) {
388
+				matched = ip
389
+				break
390
+			}
400 391
 		}
401 392
 
402
-		strAddresses = append(strAddresses, `"`+value.IP.String()+`"`)
393
+		tplODNSSNIMatch.Execute(os.Stdout, map[string]any{ //nolint: errcheck
394
+			"ip":       matched,
395
+			"hostname": d.conf.Secret.Host,
396
+		})
397
+		return true
398
+	}
399
+
400
+	strAddresses := make([]string, 0, len(res.Resolved))
401
+	for _, ip := range res.Resolved {
402
+		strAddresses = append(strAddresses, `"`+ip.String()+`"`)
403 403
 	}
404 404
 
405 405
 	tplEDNSSNIMatch.Execute(os.Stdout, map[string]any{ //nolint: errcheck
406 406
 		"hostname": d.conf.Secret.Host,
407 407
 		"resolved": strings.Join(strAddresses, ", "),
408
-		"ip4":      ourIP4,
409
-		"ip6":      ourIP6,
408
+		"ip4":      res.OurIPv4,
409
+		"ip6":      res.OurIPv6,
410 410
 	})
411 411
 
412 412
 	return false

+ 16
- 35
internal/cli/run_proxy.go 查看文件

@@ -215,72 +215,53 @@ func warnSNIMismatch(conf *config.Config, ntw mtglib.Network, log mtglib.Logger)
215 215
 		return
216 216
 	}
217 217
 
218
-	addresses, err := net.DefaultResolver.LookupIPAddr(context.Background(), host)
219
-	if err != nil {
218
+	res := runSNICheck(context.Background(), net.DefaultResolver, conf, ntw)
219
+
220
+	if res.ResolveErr != nil {
220 221
 		log.BindStr("hostname", host).
221
-			WarningError("SNI-DNS check: cannot resolve secret hostname", err)
222
+			WarningError("SNI-DNS check: cannot resolve secret hostname", res.ResolveErr)
222 223
 		return
223 224
 	}
224 225
 
225
-	ourIP4 := conf.PublicIPv4.Get(nil)
226
-	if ourIP4 == nil {
227
-		ourIP4 = getIP(ntw, "tcp4")
228
-	}
229
-
230
-	ourIP6 := conf.PublicIPv6.Get(nil)
231
-	if ourIP6 == nil {
232
-		ourIP6 = getIP(ntw, "tcp6")
233
-	}
234
-
235
-	if ourIP4 == nil && ourIP6 == nil {
226
+	if !res.PublicIPKnown() {
236 227
 		log.Warning("SNI-DNS check: cannot detect public IP address; set public-ipv4/public-ipv6 in config or run 'mtg doctor'")
237 228
 		return
238 229
 	}
239 230
 
240
-	v4Match := ourIP4 == nil
241
-	v6Match := ourIP6 == nil
242
-
243
-	for _, addr := range addresses {
244
-		if ourIP4 != nil && addr.IP.String() == ourIP4.String() {
245
-			v4Match = true
246
-		}
247
-
248
-		if ourIP6 != nil && addr.IP.String() == ourIP6.String() {
249
-			v6Match = true
250
-		}
251
-	}
231
+	v4Match := res.OurIPv4 == nil || res.IPv4Match
232
+	v6Match := res.OurIPv6 == nil || res.IPv6Match
252 233
 
253 234
 	if v4Match && v6Match {
254 235
 		return
255 236
 	}
256 237
 
257
-	resolved := make([]string, 0, len(addresses))
258
-	for _, addr := range addresses {
259
-		resolved = append(resolved, addr.IP.String())
238
+	resolved := make([]string, 0, len(res.Resolved))
239
+	for _, ip := range res.Resolved {
240
+		resolved = append(resolved, ip.String())
260 241
 	}
261 242
 
262 243
 	our := ""
263
-	if ourIP4 != nil {
264
-		our = ourIP4.String()
244
+	if res.OurIPv4 != nil {
245
+		our = res.OurIPv4.String()
265 246
 	}
266 247
 
267
-	if ourIP6 != nil {
248
+	if res.OurIPv6 != nil {
268 249
 		if our != "" {
269 250
 			our += "/"
270 251
 		}
271 252
 
272
-		our += ourIP6.String()
253
+		our += res.OurIPv6.String()
273 254
 	}
274 255
 
275 256
 	entry := log.BindStr("hostname", host).
276 257
 		BindStr("resolved", strings.Join(resolved, ", ")).
277 258
 		BindStr("public_ip", our)
278 259
 
279
-	if ourIP4 != nil {
260
+	if res.OurIPv4 != nil {
280 261
 		entry = entry.BindStr("ipv4_match", fmt.Sprintf("%t", v4Match))
281 262
 	}
282 263
 
283
-	if ourIP6 != nil {
264
+	if res.OurIPv6 != nil {
284 265
 		entry = entry.BindStr("ipv6_match", fmt.Sprintf("%t", v6Match))
285 266
 	}
286 267
 

+ 78
- 0
internal/cli/sni_check.go 查看文件

@@ -0,0 +1,78 @@
1
+package cli
2
+
3
+import (
4
+	"context"
5
+	"net"
6
+
7
+	"github.com/9seconds/mtg/v2/internal/config"
8
+	"github.com/9seconds/mtg/v2/mtglib"
9
+)
10
+
11
+// sniCheckResult holds the data gathered while comparing the secret
12
+// hostname's DNS records against this server's public IP addresses.
13
+//
14
+// IPv4Match / IPv6Match report whether a resolved record actually equals the
15
+// corresponding public IP. They are false when that family's public IP could
16
+// not be determined — there is nothing to compare against. Callers decide
17
+// what counts as a clean result from these fields: `mtg doctor` and the
18
+// startup warning apply different rules.
19
+type sniCheckResult struct {
20
+	Resolved   []net.IP
21
+	OurIPv4    net.IP
22
+	OurIPv6    net.IP
23
+	IPv4Match  bool
24
+	IPv6Match  bool
25
+	ResolveErr error
26
+}
27
+
28
+// PublicIPKnown reports whether at least one public IP family was detected.
29
+func (r sniCheckResult) PublicIPKnown() bool {
30
+	return r.OurIPv4 != nil || r.OurIPv6 != nil
31
+}
32
+
33
+// runSNICheck resolves conf.Secret.Host and compares the records with this
34
+// server's public IPv4 and IPv6. Public IPs come from config first and fall
35
+// back to on-the-fly detection via ntw. It gathers data only — it does not
36
+// decide success; see sniCheckResult.
37
+func runSNICheck(
38
+	ctx context.Context,
39
+	resolver *net.Resolver,
40
+	conf *config.Config,
41
+	ntw mtglib.Network,
42
+) sniCheckResult {
43
+	res := sniCheckResult{}
44
+
45
+	addrs, err := resolver.LookupIPAddr(ctx, conf.Secret.Host)
46
+	if err != nil {
47
+		res.ResolveErr = err
48
+
49
+		return res
50
+	}
51
+
52
+	res.Resolved = make([]net.IP, 0, len(addrs))
53
+	for _, a := range addrs {
54
+		res.Resolved = append(res.Resolved, a.IP)
55
+	}
56
+
57
+	res.OurIPv4 = conf.PublicIPv4.Get(nil)
58
+	if res.OurIPv4 == nil {
59
+		res.OurIPv4 = getIP(ntw, "tcp4")
60
+	}
61
+
62
+	res.OurIPv6 = conf.PublicIPv6.Get(nil)
63
+	if res.OurIPv6 == nil {
64
+		res.OurIPv6 = getIP(ntw, "tcp6")
65
+	}
66
+
67
+	for _, ip := range res.Resolved {
68
+		if res.OurIPv4 != nil && ip.String() == res.OurIPv4.String() {
69
+			res.IPv4Match = true
70
+		}
71
+
72
+		if res.OurIPv6 != nil && ip.String() == res.OurIPv6.String() {
73
+			res.IPv6Match = true
74
+		}
75
+	}
76
+
77
+	return res
78
+}

Loading…
取消
儲存