|
|
@@ -0,0 +1,174 @@
|
|
|
1
|
+package cli
|
|
|
2
|
+
|
|
|
3
|
+import (
|
|
|
4
|
+ "context"
|
|
|
5
|
+ "crypto/ecdsa"
|
|
|
6
|
+ "crypto/elliptic"
|
|
|
7
|
+ "crypto/rand"
|
|
|
8
|
+ "crypto/tls"
|
|
|
9
|
+ "crypto/x509"
|
|
|
10
|
+ "crypto/x509/pkix"
|
|
|
11
|
+ "math/big"
|
|
|
12
|
+ "net"
|
|
|
13
|
+ "strings"
|
|
|
14
|
+ "testing"
|
|
|
15
|
+ "time"
|
|
|
16
|
+)
|
|
|
17
|
+
|
|
|
18
|
+// makeCert builds a self-signed leaf certificate valid for the supplied DNS
|
|
|
19
|
+// name (and IP, so dialing 127.0.0.1 still reaches the listener) plus a
|
|
|
20
|
+// matching tls.Config and an x509 pool that trusts it.
|
|
|
21
|
+func makeCert(t *testing.T, dnsName string, notAfter time.Time) (tls.Certificate, *x509.CertPool) {
|
|
|
22
|
+ t.Helper()
|
|
|
23
|
+
|
|
|
24
|
+ key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
|
|
25
|
+ if err != nil {
|
|
|
26
|
+ t.Fatalf("generate key: %v", err)
|
|
|
27
|
+ }
|
|
|
28
|
+
|
|
|
29
|
+ tmpl := &x509.Certificate{
|
|
|
30
|
+ SerialNumber: big.NewInt(1),
|
|
|
31
|
+ Subject: pkix.Name{CommonName: dnsName},
|
|
|
32
|
+ NotBefore: time.Now().Add(-time.Hour),
|
|
|
33
|
+ NotAfter: notAfter,
|
|
|
34
|
+ KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
|
|
|
35
|
+ ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
|
|
36
|
+ BasicConstraintsValid: true,
|
|
|
37
|
+ IsCA: true,
|
|
|
38
|
+ DNSNames: []string{dnsName},
|
|
|
39
|
+ IPAddresses: []net.IP{net.IPv4(127, 0, 0, 1), net.IPv6loopback},
|
|
|
40
|
+ }
|
|
|
41
|
+
|
|
|
42
|
+ der, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &key.PublicKey, key)
|
|
|
43
|
+ if err != nil {
|
|
|
44
|
+ t.Fatalf("create certificate: %v", err)
|
|
|
45
|
+ }
|
|
|
46
|
+
|
|
|
47
|
+ leaf, err := x509.ParseCertificate(der)
|
|
|
48
|
+ if err != nil {
|
|
|
49
|
+ t.Fatalf("parse certificate: %v", err)
|
|
|
50
|
+ }
|
|
|
51
|
+
|
|
|
52
|
+ pool := x509.NewCertPool()
|
|
|
53
|
+ pool.AddCert(leaf)
|
|
|
54
|
+
|
|
|
55
|
+ return tls.Certificate{Certificate: [][]byte{der}, PrivateKey: key, Leaf: leaf}, pool
|
|
|
56
|
+}
|
|
|
57
|
+
|
|
|
58
|
+// startTLSServer spins up a TLS listener that completes handshakes using cert
|
|
|
59
|
+// and returns its address. It is closed when the test finishes.
|
|
|
60
|
+func startTLSServer(t *testing.T, cert tls.Certificate) string {
|
|
|
61
|
+ t.Helper()
|
|
|
62
|
+
|
|
|
63
|
+ ln, err := tls.Listen("tcp", "127.0.0.1:0", &tls.Config{
|
|
|
64
|
+ Certificates: []tls.Certificate{cert},
|
|
|
65
|
+ MinVersion: tls.VersionTLS12,
|
|
|
66
|
+ })
|
|
|
67
|
+ if err != nil {
|
|
|
68
|
+ t.Fatalf("listen: %v", err)
|
|
|
69
|
+ }
|
|
|
70
|
+
|
|
|
71
|
+ t.Cleanup(func() { _ = ln.Close() })
|
|
|
72
|
+
|
|
|
73
|
+ go func() {
|
|
|
74
|
+ for {
|
|
|
75
|
+ conn, err := ln.Accept()
|
|
|
76
|
+ if err != nil {
|
|
|
77
|
+ return
|
|
|
78
|
+ }
|
|
|
79
|
+
|
|
|
80
|
+ go func() {
|
|
|
81
|
+ // Drive the handshake so the client side completes, then drop.
|
|
|
82
|
+ if tc, ok := conn.(*tls.Conn); ok {
|
|
|
83
|
+ _ = tc.Handshake()
|
|
|
84
|
+ }
|
|
|
85
|
+ _ = conn.Close()
|
|
|
86
|
+ }()
|
|
|
87
|
+ }
|
|
|
88
|
+ }()
|
|
|
89
|
+
|
|
|
90
|
+ return ln.Addr().String()
|
|
|
91
|
+}
|
|
|
92
|
+
|
|
|
93
|
+func TestProbeFrontingTLS_ValidCert(t *testing.T) {
|
|
|
94
|
+ cert, pool := makeCert(t, "front.example.org", time.Now().Add(24*time.Hour))
|
|
|
95
|
+ addr := startTLSServer(t, cert)
|
|
|
96
|
+
|
|
|
97
|
+ ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
|
98
|
+ defer cancel()
|
|
|
99
|
+
|
|
|
100
|
+ err := probeFrontingTLS(ctx, &net.Dialer{}, addr, "front.example.org", pool)
|
|
|
101
|
+ if err != nil {
|
|
|
102
|
+ t.Fatalf("expected success, got error: %v", err)
|
|
|
103
|
+ }
|
|
|
104
|
+}
|
|
|
105
|
+
|
|
|
106
|
+func TestProbeFrontingTLS_WrongHost(t *testing.T) {
|
|
|
107
|
+ // Cert is for front.example.org, but we verify against other.example.org.
|
|
|
108
|
+ cert, pool := makeCert(t, "front.example.org", time.Now().Add(24*time.Hour))
|
|
|
109
|
+ addr := startTLSServer(t, cert)
|
|
|
110
|
+
|
|
|
111
|
+ ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
|
112
|
+ defer cancel()
|
|
|
113
|
+
|
|
|
114
|
+ err := probeFrontingTLS(ctx, &net.Dialer{}, addr, "other.example.org", pool)
|
|
|
115
|
+ if err == nil {
|
|
|
116
|
+ t.Fatal("expected SAN-mismatch failure, got success")
|
|
|
117
|
+ }
|
|
|
118
|
+ if !strings.Contains(err.Error(), "x509") {
|
|
|
119
|
+ t.Fatalf("expected x509 verification error, got: %v", err)
|
|
|
120
|
+ }
|
|
|
121
|
+}
|
|
|
122
|
+
|
|
|
123
|
+func TestProbeFrontingTLS_UntrustedCA(t *testing.T) {
|
|
|
124
|
+ // Server cert is self-signed; we hand the client an empty pool that does
|
|
|
125
|
+ // not trust it. Default verification must reject.
|
|
|
126
|
+ cert, _ := makeCert(t, "front.example.org", time.Now().Add(24*time.Hour))
|
|
|
127
|
+ addr := startTLSServer(t, cert)
|
|
|
128
|
+
|
|
|
129
|
+ ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
|
130
|
+ defer cancel()
|
|
|
131
|
+
|
|
|
132
|
+ err := probeFrontingTLS(ctx, &net.Dialer{}, addr, "front.example.org", x509.NewCertPool())
|
|
|
133
|
+ if err == nil {
|
|
|
134
|
+ t.Fatal("expected untrusted-CA failure, got success")
|
|
|
135
|
+ }
|
|
|
136
|
+ if !strings.Contains(err.Error(), "x509") {
|
|
|
137
|
+ t.Fatalf("expected x509 verification error, got: %v", err)
|
|
|
138
|
+ }
|
|
|
139
|
+}
|
|
|
140
|
+
|
|
|
141
|
+func TestProbeFrontingTLS_ExpiredCert(t *testing.T) {
|
|
|
142
|
+ // Same trust anchor, but the cert is already expired.
|
|
|
143
|
+ cert, pool := makeCert(t, "front.example.org", time.Now().Add(-time.Hour))
|
|
|
144
|
+ addr := startTLSServer(t, cert)
|
|
|
145
|
+
|
|
|
146
|
+ ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
|
147
|
+ defer cancel()
|
|
|
148
|
+
|
|
|
149
|
+ err := probeFrontingTLS(ctx, &net.Dialer{}, addr, "front.example.org", pool)
|
|
|
150
|
+ if err == nil {
|
|
|
151
|
+ t.Fatal("expected expiry failure, got success")
|
|
|
152
|
+ }
|
|
|
153
|
+ if !strings.Contains(err.Error(), "expired") {
|
|
|
154
|
+ t.Fatalf("expected expiry error, got: %v", err)
|
|
|
155
|
+ }
|
|
|
156
|
+}
|
|
|
157
|
+
|
|
|
158
|
+func TestProbeFrontingTLS_OverrideDialDifferentFromSNI(t *testing.T) {
|
|
|
159
|
+ // Domain-fronting override: dial one address (the listener bound to
|
|
|
160
|
+ // 127.0.0.1), but verify the cert against the secret host name. The cert
|
|
|
161
|
+ // is issued for the secret host, so verification must pass even though the
|
|
|
162
|
+ // dial target is a bare IP:port.
|
|
|
163
|
+ cert, pool := makeCert(t, "secret.example.org", time.Now().Add(24*time.Hour))
|
|
|
164
|
+ addr := startTLSServer(t, cert) // e.g. 127.0.0.1:NNNNN
|
|
|
165
|
+
|
|
|
166
|
+ ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
|
167
|
+ defer cancel()
|
|
|
168
|
+
|
|
|
169
|
+ // dial-target (addr, an IP:port) != SNI (secret.example.org)
|
|
|
170
|
+ err := probeFrontingTLS(ctx, &net.Dialer{}, addr, "secret.example.org", pool)
|
|
|
171
|
+ if err != nil {
|
|
|
172
|
+ t.Fatalf("expected success when dialing override addr with secret-host SNI, got: %v", err)
|
|
|
173
|
+ }
|
|
|
174
|
+}
|