Kaynağa Gözat

Merge pull request #528 from 9seconds/refactor/consolidate-sni-check

internal/cli: consolidate duplicated SNI-DNS check
pull/540/head
Sergei Arkhipov 13 saat önce
ebeveyn
işleme
6a939eef6a
No account linked to committer's email address
3 değiştirilmiş dosya ile 120 ekleme ve 61 silme
  1. 26
    26
      internal/cli/doctor.go
  2. 16
    35
      internal/cli/run_proxy.go
  3. 78
    0
      internal/cli/sni_check.go

+ 26
- 26
internal/cli/doctor.go Dosyayı Görüntüle

@@ -371,26 +371,17 @@ func (d *Doctor) checkFrontingDomain(ntw mtglib.Network) bool {
371 371
 }
372 372
 
373 373
 func (d *Doctor) checkSecretHost(resolver *net.Resolver, ntw mtglib.Network) bool {
374
-	addresses, err := resolver.LookupIPAddr(context.Background(), d.conf.Secret.Host)
375
-	if err != nil {
374
+	res := runSNICheck(context.Background(), resolver, d.conf, ntw)
375
+
376
+	if res.ResolveErr != nil {
376 377
 		tplError.Execute(os.Stdout, map[string]any{ //nolint: errcheck
377 378
 			"description": fmt.Sprintf("cannot resolve DNS name of %s", d.conf.Secret.Host),
378
-			"error":       err,
379
+			"error":       res.ResolveErr,
379 380
 		})
380 381
 		return false
381 382
 	}
382 383
 
383
-	ourIP4 := d.conf.PublicIPv4.Get(nil)
384
-	if ourIP4 == nil {
385
-		ourIP4 = getIP(ntw, "tcp4")
386
-	}
387
-
388
-	ourIP6 := d.conf.PublicIPv6.Get(nil)
389
-	if ourIP6 == nil {
390
-		ourIP6 = getIP(ntw, "tcp6")
391
-	}
392
-
393
-	if ourIP4 == nil && ourIP6 == nil {
384
+	if !res.PublicIPKnown() {
394 385
 		tplError.Execute(os.Stdout, map[string]any{ //nolint: errcheck
395 386
 			"description": "cannot detect public IP address",
396 387
 			"error":       errors.New("cannot detect automatically and public-ipv4/public-ipv6 are not set in config"),
@@ -398,25 +389,34 @@ func (d *Doctor) checkSecretHost(resolver *net.Resolver, ntw mtglib.Network) boo
398 389
 		return false
399 390
 	}
400 391
 
401
-	strAddresses := []string{}
402
-	for _, value := range addresses {
403
-		if (ourIP4 != nil && value.IP.String() == ourIP4.String()) ||
404
-			(ourIP6 != nil && value.IP.String() == ourIP6.String()) {
405
-			tplODNSSNIMatch.Execute(os.Stdout, map[string]any{ //nolint: errcheck
406
-				"ip":       value.IP,
407
-				"hostname": d.conf.Secret.Host,
408
-			})
409
-			return true
392
+	if res.IPv4Match || res.IPv6Match {
393
+		var matched net.IP
394
+
395
+		for _, ip := range res.Resolved {
396
+			if (res.OurIPv4 != nil && ip.String() == res.OurIPv4.String()) ||
397
+				(res.OurIPv6 != nil && ip.String() == res.OurIPv6.String()) {
398
+				matched = ip
399
+				break
400
+			}
410 401
 		}
411 402
 
412
-		strAddresses = append(strAddresses, `"`+value.IP.String()+`"`)
403
+		tplODNSSNIMatch.Execute(os.Stdout, map[string]any{ //nolint: errcheck
404
+			"ip":       matched,
405
+			"hostname": d.conf.Secret.Host,
406
+		})
407
+		return true
408
+	}
409
+
410
+	strAddresses := make([]string, 0, len(res.Resolved))
411
+	for _, ip := range res.Resolved {
412
+		strAddresses = append(strAddresses, `"`+ip.String()+`"`)
413 413
 	}
414 414
 
415 415
 	tplEDNSSNIMatch.Execute(os.Stdout, map[string]any{ //nolint: errcheck
416 416
 		"hostname": d.conf.Secret.Host,
417 417
 		"resolved": strings.Join(strAddresses, ", "),
418
-		"ip4":      ourIP4,
419
-		"ip6":      ourIP6,
418
+		"ip4":      res.OurIPv4,
419
+		"ip6":      res.OurIPv6,
420 420
 	})
421 421
 
422 422
 	return false

+ 16
- 35
internal/cli/run_proxy.go Dosyayı Görüntüle

@@ -215,72 +215,53 @@ func warnSNIMismatch(conf *config.Config, ntw mtglib.Network, log mtglib.Logger)
215 215
 		return
216 216
 	}
217 217
 
218
-	addresses, err := net.DefaultResolver.LookupIPAddr(context.Background(), host)
219
-	if err != nil {
218
+	res := runSNICheck(context.Background(), net.DefaultResolver, conf, ntw)
219
+
220
+	if res.ResolveErr != nil {
220 221
 		log.BindStr("hostname", host).
221
-			WarningError("SNI-DNS check: cannot resolve secret hostname", err)
222
+			WarningError("SNI-DNS check: cannot resolve secret hostname", res.ResolveErr)
222 223
 		return
223 224
 	}
224 225
 
225
-	ourIP4 := conf.PublicIPv4.Get(nil)
226
-	if ourIP4 == nil {
227
-		ourIP4 = getIP(ntw, "tcp4")
228
-	}
229
-
230
-	ourIP6 := conf.PublicIPv6.Get(nil)
231
-	if ourIP6 == nil {
232
-		ourIP6 = getIP(ntw, "tcp6")
233
-	}
234
-
235
-	if ourIP4 == nil && ourIP6 == nil {
226
+	if !res.PublicIPKnown() {
236 227
 		log.Warning("SNI-DNS check: cannot detect public IP address; set public-ipv4/public-ipv6 in config or run 'mtg doctor'")
237 228
 		return
238 229
 	}
239 230
 
240
-	v4Match := ourIP4 == nil
241
-	v6Match := ourIP6 == nil
242
-
243
-	for _, addr := range addresses {
244
-		if ourIP4 != nil && addr.IP.String() == ourIP4.String() {
245
-			v4Match = true
246
-		}
247
-
248
-		if ourIP6 != nil && addr.IP.String() == ourIP6.String() {
249
-			v6Match = true
250
-		}
251
-	}
231
+	v4Match := res.OurIPv4 == nil || res.IPv4Match
232
+	v6Match := res.OurIPv6 == nil || res.IPv6Match
252 233
 
253 234
 	if v4Match && v6Match {
254 235
 		return
255 236
 	}
256 237
 
257
-	resolved := make([]string, 0, len(addresses))
258
-	for _, addr := range addresses {
259
-		resolved = append(resolved, addr.IP.String())
238
+	resolved := make([]string, 0, len(res.Resolved))
239
+	for _, ip := range res.Resolved {
240
+		resolved = append(resolved, ip.String())
260 241
 	}
261 242
 
262 243
 	our := ""
263
-	if ourIP4 != nil {
264
-		our = ourIP4.String()
244
+	if res.OurIPv4 != nil {
245
+		our = res.OurIPv4.String()
265 246
 	}
266 247
 
267
-	if ourIP6 != nil {
248
+	if res.OurIPv6 != nil {
268 249
 		if our != "" {
269 250
 			our += "/"
270 251
 		}
271 252
 
272
-		our += ourIP6.String()
253
+		our += res.OurIPv6.String()
273 254
 	}
274 255
 
275 256
 	entry := log.BindStr("hostname", host).
276 257
 		BindStr("resolved", strings.Join(resolved, ", ")).
277 258
 		BindStr("public_ip", our)
278 259
 
279
-	if ourIP4 != nil {
260
+	if res.OurIPv4 != nil {
280 261
 		entry = entry.BindStr("ipv4_match", fmt.Sprintf("%t", v4Match))
281 262
 	}
282 263
 
283
-	if ourIP6 != nil {
264
+	if res.OurIPv6 != nil {
284 265
 		entry = entry.BindStr("ipv6_match", fmt.Sprintf("%t", v6Match))
285 266
 	}
286 267
 

+ 78
- 0
internal/cli/sni_check.go Dosyayı Görüntüle

@@ -0,0 +1,78 @@
1
+package cli
2
+
3
+import (
4
+	"context"
5
+	"net"
6
+
7
+	"github.com/9seconds/mtg/v2/internal/config"
8
+	"github.com/9seconds/mtg/v2/mtglib"
9
+)
10
+
11
+// sniCheckResult holds the data gathered while comparing the secret
12
+// hostname's DNS records against this server's public IP addresses.
13
+//
14
+// IPv4Match / IPv6Match report whether a resolved record actually equals the
15
+// corresponding public IP. They are false when that family's public IP could
16
+// not be determined — there is nothing to compare against. Callers decide
17
+// what counts as a clean result from these fields: `mtg doctor` and the
18
+// startup warning apply different rules.
19
+type sniCheckResult struct {
20
+	Resolved   []net.IP
21
+	OurIPv4    net.IP
22
+	OurIPv6    net.IP
23
+	IPv4Match  bool
24
+	IPv6Match  bool
25
+	ResolveErr error
26
+}
27
+
28
+// PublicIPKnown reports whether at least one public IP family was detected.
29
+func (r sniCheckResult) PublicIPKnown() bool {
30
+	return r.OurIPv4 != nil || r.OurIPv6 != nil
31
+}
32
+
33
+// runSNICheck resolves conf.Secret.Host and compares the records with this
34
+// server's public IPv4 and IPv6. Public IPs come from config first and fall
35
+// back to on-the-fly detection via ntw. It gathers data only — it does not
36
+// decide success; see sniCheckResult.
37
+func runSNICheck(
38
+	ctx context.Context,
39
+	resolver *net.Resolver,
40
+	conf *config.Config,
41
+	ntw mtglib.Network,
42
+) sniCheckResult {
43
+	res := sniCheckResult{}
44
+
45
+	addrs, err := resolver.LookupIPAddr(ctx, conf.Secret.Host)
46
+	if err != nil {
47
+		res.ResolveErr = err
48
+
49
+		return res
50
+	}
51
+
52
+	res.Resolved = make([]net.IP, 0, len(addrs))
53
+	for _, a := range addrs {
54
+		res.Resolved = append(res.Resolved, a.IP)
55
+	}
56
+
57
+	res.OurIPv4 = conf.PublicIPv4.Get(nil)
58
+	if res.OurIPv4 == nil {
59
+		res.OurIPv4 = getIP(ntw, "tcp4")
60
+	}
61
+
62
+	res.OurIPv6 = conf.PublicIPv6.Get(nil)
63
+	if res.OurIPv6 == nil {
64
+		res.OurIPv6 = getIP(ntw, "tcp6")
65
+	}
66
+
67
+	for _, ip := range res.Resolved {
68
+		if res.OurIPv4 != nil && ip.String() == res.OurIPv4.String() {
69
+			res.IPv4Match = true
70
+		}
71
+
72
+		if res.OurIPv6 != nil && ip.String() == res.OurIPv6.String() {
73
+			res.IPv6Match = true
74
+		}
75
+	}
76
+
77
+	return res
78
+}

Loading…
İptal
Kaydet