Browse Source

Fix SNI check failing when one IP family is undetectable

runSNICheck wired each family's getIP failure through a shared
context.WithCancelCause, so a single family's detection failure (for
example tcp6 on an IPv4-only-egress server) made the whole check return
an error even when the other family was detected and matched. Both
callers treat that error as fatal, so a server that is fine on IPv4
failed the SNI check outright -- the exact audience of #529.

Mirror the graceful per-family handling access.go already uses: discard
the per-family getIP error and report an undetectable family through an
empty OurIP4/OurIP6, which both callers already surface via their
"cannot detect public IP address" branch. The error return is now
reserved for genuine DNS-resolution failure. Removing the shared cancel
also makes the two families independent, so a fast-failing family can no
longer abort the other family's in-flight detection.

Add a regression test that drives the real runSNICheck over a loopback
DNS fake and an IPv4-only-egress network fake.
pull/557/head
Alexey Dolotov 1 week ago
parent
commit
8cf62d7375
2 changed files with 199 additions and 16 deletions
  1. 3
    16
      internal/cli/sni_check.go
  2. 196
    0
      internal/cli/sni_check_test.go

+ 3
- 16
internal/cli/sni_check.go View File

@@ -25,9 +25,6 @@ func runSNICheck(
25 25
 ) (sniCheckResult, error) {
26 26
 	res := sniCheckResult{}
27 27
 
28
-	ctx, cancelCause := context.WithCancelCause(ctx)
29
-	defer cancelCause(nil)
30
-
31 28
 	addrs, err := resolver.LookupIPAddr(ctx, conf.Secret.Host)
32 29
 	if err != nil {
33 30
 		return res, fmt.Errorf("cannot resolve addresses of %s: %w", conf.Secret.Host, err)
@@ -49,14 +46,9 @@ func runSNICheck(
49 46
 
50 47
 	if len(res.ResolvedIP4) > 0 {
51 48
 		wg.Go(func() {
52
-			var err error
53
-
54 49
 			ip := conf.PublicIPv4.Get(nil)
55 50
 			if ip == nil {
56
-				ip, err = getIP(ctx, ntw, "tcp4")
57
-				if err != nil {
58
-					cancelCause(err)
59
-				}
51
+				ip, _ = getIP(ctx, ntw, "tcp4")
60 52
 			}
61 53
 
62 54
 			if ip != nil {
@@ -67,14 +59,9 @@ func runSNICheck(
67 59
 
68 60
 	if len(res.ResolvedIP6) > 0 {
69 61
 		wg.Go(func() {
70
-			var err error
71
-
72 62
 			ip := conf.PublicIPv6.Get(nil)
73 63
 			if ip == nil {
74
-				ip, err = getIP(ctx, ntw, "tcp6")
75
-				if err != nil {
76
-					cancelCause(err)
77
-				}
64
+				ip, _ = getIP(ctx, ntw, "tcp6")
78 65
 			}
79 66
 
80 67
 			if ip != nil {
@@ -85,5 +72,5 @@ func runSNICheck(
85 72
 
86 73
 	wg.Wait()
87 74
 
88
-	return res, context.Cause(ctx)
75
+	return res, nil
89 76
 }

+ 196
- 0
internal/cli/sni_check_test.go View File

@@ -0,0 +1,196 @@
1
+package cli
2
+
3
+import (
4
+	"context"
5
+	"io"
6
+	"net"
7
+	"net/http"
8
+	"strings"
9
+	"testing"
10
+
11
+	"github.com/9seconds/mtg/v2/essentials"
12
+	"github.com/9seconds/mtg/v2/internal/config"
13
+	"github.com/9seconds/mtg/v2/mtglib"
14
+	"github.com/stretchr/testify/require"
15
+	"golang.org/x/net/dns/dnsmessage"
16
+)
17
+
18
+// startSNITestDNS spins up a loopback UDP resolver that answers every query
19
+// with the given A and AAAA records, so runSNICheck sees a dual-stack secret
20
+// host without touching the real network. It returns a *net.Resolver wired to
21
+// it.
22
+func startSNITestDNS(t *testing.T, a, aaaa net.IP) *net.Resolver {
23
+	t.Helper()
24
+
25
+	pc, err := net.ListenPacket("udp", "127.0.0.1:0")
26
+	require.NoError(t, err)
27
+	t.Cleanup(func() { pc.Close() }) //nolint: errcheck
28
+
29
+	go func() {
30
+		buf := make([]byte, 512)
31
+
32
+		for {
33
+			n, addr, err := pc.ReadFrom(buf)
34
+			if err != nil {
35
+				return
36
+			}
37
+
38
+			var parser dnsmessage.Parser
39
+
40
+			hdr, err := parser.Start(buf[:n])
41
+			if err != nil {
42
+				continue
43
+			}
44
+
45
+			question, err := parser.Question()
46
+			if err != nil {
47
+				continue
48
+			}
49
+
50
+			builder := dnsmessage.NewBuilder(nil, dnsmessage.Header{
51
+				ID:                 hdr.ID,
52
+				Response:           true,
53
+				RecursionAvailable: true,
54
+			})
55
+			builder.EnableCompression()
56
+			_ = builder.StartQuestions()
57
+			_ = builder.Question(question)
58
+			_ = builder.StartAnswers()
59
+
60
+			rh := dnsmessage.ResourceHeader{
61
+				Name:  question.Name,
62
+				Class: dnsmessage.ClassINET,
63
+				TTL:   60,
64
+			}
65
+
66
+			switch question.Type {
67
+			case dnsmessage.TypeA:
68
+				rh.Type = dnsmessage.TypeA
69
+				var v4 [4]byte
70
+				copy(v4[:], a.To4())
71
+				_ = builder.AResource(rh, dnsmessage.AResource{A: v4})
72
+			case dnsmessage.TypeAAAA:
73
+				rh.Type = dnsmessage.TypeAAAA
74
+				var v6 [16]byte
75
+				copy(v6[:], aaaa.To16())
76
+				_ = builder.AAAAResource(rh, dnsmessage.AAAAResource{AAAA: v6})
77
+			}
78
+
79
+			msg, err := builder.Finish()
80
+			if err != nil {
81
+				continue
82
+			}
83
+
84
+			pc.WriteTo(msg, addr) //nolint: errcheck
85
+		}
86
+	}()
87
+
88
+	dnsAddr := pc.LocalAddr().String()
89
+
90
+	return &net.Resolver{
91
+		PreferGo: true,
92
+		Dial: func(ctx context.Context, _, _ string) (net.Conn, error) {
93
+			var d net.Dialer
94
+
95
+			return d.DialContext(ctx, "udp", dnsAddr)
96
+		},
97
+	}
98
+}
99
+
100
+// ipv4OnlyEgressNetwork fakes mtglib.Network so that public-IP detection
101
+// succeeds over tcp4 and fails over tcp6 — the classic IPv4-only-egress
102
+// server. getIP's per-protocol dial is routed at a loopback listener: a tcp4
103
+// dial to 127.0.0.1 connects, a tcp6 dial to the same address fails ("no
104
+// suitable address"), so we exercise the real getIP code path without the
105
+// internet.
106
+type ipv4OnlyEgressNetwork struct {
107
+	listenerAddr string
108
+	detectedV4   string
109
+}
110
+
111
+func (n *ipv4OnlyEgressNetwork) Dial(_, _ string) (essentials.Conn, error) {
112
+	panic("unused")
113
+}
114
+
115
+func (n *ipv4OnlyEgressNetwork) DialContext(_ context.Context, _, _ string) (essentials.Conn, error) {
116
+	panic("unused")
117
+}
118
+
119
+func (n *ipv4OnlyEgressNetwork) NativeDialer() *net.Dialer {
120
+	return &net.Dialer{}
121
+}
122
+
123
+func (n *ipv4OnlyEgressNetwork) MakeHTTPClient(
124
+	dialFunc func(ctx context.Context, network, address string) (essentials.Conn, error),
125
+) *http.Client {
126
+	return &http.Client{
127
+		Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) {
128
+			conn, err := dialFunc(req.Context(), "tcp", n.listenerAddr)
129
+			if err != nil {
130
+				return nil, err
131
+			}
132
+
133
+			conn.Close() //nolint: errcheck
134
+
135
+			return &http.Response{
136
+				StatusCode: http.StatusOK,
137
+				Body:       io.NopCloser(strings.NewReader(n.detectedV4)),
138
+				Header:     make(http.Header),
139
+			}, nil
140
+		}),
141
+	}
142
+}
143
+
144
+type roundTripFunc func(*http.Request) (*http.Response, error)
145
+
146
+func (f roundTripFunc) RoundTrip(r *http.Request) (*http.Response, error) {
147
+	return f(r)
148
+}
149
+
150
+// TestRunSNICheckIPv4OnlyEgressGraceful reproduces the #529/#542 regression:
151
+// a dual-stack secret host on a server whose IPv6 egress is down. The tcp6
152
+// public-IP detection fails, but the tcp4 detection succeeds and matches the
153
+// host's A record, so the SNI check must NOT report a hard error — one
154
+// family being undetectable is graceful degradation, not failure.
155
+func TestRunSNICheckIPv4OnlyEgressGraceful(t *testing.T) {
156
+	const ourV4 = "192.0.2.4" // RFC 5737 TEST-NET-1
157
+
158
+	resolver := startSNITestDNS(t, net.ParseIP(ourV4), net.ParseIP("2001:db8::1")) // RFC 3849 doc range
159
+
160
+	// Loopback target for getIP's dial. Keep it the IPv4 literal 127.0.0.1: a
161
+	// "tcp6" dial to it fails deterministically ("no suitable address") on any
162
+	// host regardless of IPv6 connectivity, which is what makes tcp6 detection
163
+	// fail here. Do not replace with a real ::1 setup — that reintroduces flake.
164
+	listener, err := net.Listen("tcp", "127.0.0.1:0")
165
+	require.NoError(t, err)
166
+	t.Cleanup(func() { listener.Close() }) //nolint: errcheck
167
+
168
+	go func() {
169
+		for {
170
+			conn, err := listener.Accept()
171
+			if err != nil {
172
+				return
173
+			}
174
+
175
+			conn.Close() //nolint: errcheck
176
+		}
177
+	}()
178
+
179
+	ntw := &ipv4OnlyEgressNetwork{
180
+		listenerAddr: listener.Addr().String(),
181
+		detectedV4:   ourV4,
182
+	}
183
+
184
+	conf := &config.Config{}
185
+	conf.Secret.Host = "secret-host.test"
186
+
187
+	res, err := runSNICheck(context.Background(), conf, resolver, ntw)
188
+
189
+	// The load-bearing assertion: a single family's detection failure must not
190
+	// poison the whole result. Before the fix this returns a non-nil error.
191
+	require.NoError(t, err)
192
+	require.Equal(t, ourV4, res.OurIP4, "IPv4 public IP should be detected and match the A record")
193
+	require.Empty(t, res.OurIP6, "IPv6 is undetectable on IPv4-only egress; must degrade, not error")
194
+}
195
+
196
+var _ mtglib.Network = (*ipv4OnlyEgressNetwork)(nil)

Loading…
Cancel
Save