Просмотр исходного кода

Add multi-secret support and per-user stats API

Support multiple named secrets in [secrets] config section. During
FakeTLS handshake each secret is tried until HMAC validates. Matched
secret name is logged and used for per-user statistics tracking.

Built-in HTTP stats API (configurable via api-bind-to) exposes
GET /stats with per-user connection counts, bytes in/out, and
last-seen timestamps.

Backward compatible: single "secret" config key still works.
pull/434/head
Alexey Dolotov 1 месяц назад
Родитель
Сommit
1a450e3c45

+ 12
- 0
example.config.toml Просмотреть файл

@@ -20,6 +20,18 @@ debug = true
20 20
 # should either be base64-encoded or starts with ee.
21 21
 secret = "ee367a189aee18fa31c190054efd4a8e9573746f726167652e676f6f676c65617069732e636f6d"
22 22
 
23
+# For multi-user support, use the [secrets] section instead of "secret".
24
+# Each key is a user name, used for per-user stats tracking.
25
+# All secrets must use the same hostname.
26
+# [secrets]
27
+# alice = "ee367a189aee18fa31c190054efd4a8e9573746f726167652e676f6f676c65617069732e636f6d"
28
+# bob = "ee0123456789abcdef0123456789abcd9573746f726167652e676f6f676c65617069732e636f6d"
29
+
30
+# Host:port pair to bind the built-in stats HTTP API server.
31
+# GET /stats returns per-user connection counts and traffic.
32
+# If not set, the stats API is not started.
33
+# api-bind-to = "127.0.0.1:9090"
34
+
23 35
 # Host:port pair to run proxy on.
24 36
 bind-to = "0.0.0.0:3128"
25 37
 

+ 56
- 10
internal/cli/access.go Просмотреть файл

@@ -6,20 +6,25 @@ import (
6 6
 	"net"
7 7
 	"net/url"
8 8
 	"os"
9
+	"sort"
9 10
 	"strconv"
10 11
 	"sync"
11 12
 
12 13
 	"github.com/9seconds/mtg/v2/internal/config"
13 14
 	"github.com/9seconds/mtg/v2/internal/utils"
15
+	"github.com/9seconds/mtg/v2/mtglib"
14 16
 )
15 17
 
18
+type accessResponseSecret struct {
19
+	Hex    string `json:"hex"`
20
+	Base64 string `json:"base64"`
21
+}
22
+
16 23
 type accessResponse struct {
17
-	IPv4   *accessResponseURLs `json:"ipv4,omitempty"`
18
-	IPv6   *accessResponseURLs `json:"ipv6,omitempty"`
19
-	Secret struct {
20
-		Hex    string `json:"hex"`
21
-		Base64 string `json:"base64"`
22
-	} `json:"secret"`
24
+	IPv4    *accessResponseURLs             `json:"ipv4,omitempty"`
25
+	IPv6    *accessResponseURLs             `json:"ipv6,omitempty"`
26
+	Secret  accessResponseSecret            `json:"secret"`
27
+	Secrets map[string]accessResponseSecret `json:"secrets,omitempty"`
23 28
 }
24 29
 
25 30
 type accessResponseURLs struct {
@@ -46,8 +51,34 @@ func (a *Access) Run(cli *CLI, version string) error {
46 51
 	}
47 52
 
48 53
 	resp := &accessResponse{}
49
-	resp.Secret.Base64 = conf.Secret.Base64()
50
-	resp.Secret.Hex = conf.Secret.Hex()
54
+	secrets := conf.GetSecrets()
55
+
56
+	// Sort secret names for deterministic "first secret" selection.
57
+	sortedNames := make([]string, 0, len(secrets))
58
+	for name := range secrets {
59
+		sortedNames = append(sortedNames, name)
60
+	}
61
+
62
+	sort.Strings(sortedNames)
63
+
64
+	// For backward compatibility, populate the single Secret field with the
65
+	// first secret (sorted alphabetically).
66
+	if len(sortedNames) > 0 {
67
+		first := secrets[sortedNames[0]]
68
+		resp.Secret.Base64 = first.Base64()
69
+		resp.Secret.Hex = first.Hex()
70
+	}
71
+
72
+	if len(secrets) > 1 {
73
+		resp.Secrets = make(map[string]accessResponseSecret, len(secrets))
74
+
75
+		for name, s := range secrets {
76
+			resp.Secrets[name] = accessResponseSecret{
77
+				Hex:    s.Hex(),
78
+				Base64: s.Base64(),
79
+			}
80
+		}
81
+	}
51 82
 
52 83
 	ntw, err := makeNetwork(conf, version)
53 84
 	if err != nil {
@@ -114,10 +145,25 @@ func (a *Access) makeURLs(conf *config.Config, ip net.IP) *accessResponseURLs {
114 145
 	values.Set("server", ip.String())
115 146
 	values.Set("port", strconv.Itoa(int(portNo)))
116 147
 
148
+	// Use the first available secret (sorted) for URL generation.
149
+	secrets := conf.GetSecrets()
150
+	names := make([]string, 0, len(secrets))
151
+
152
+	for name := range secrets {
153
+		names = append(names, name)
154
+	}
155
+
156
+	sort.Strings(names)
157
+
158
+	var firstSecret mtglib.Secret
159
+	if len(names) > 0 {
160
+		firstSecret = secrets[names[0]]
161
+	}
162
+
117 163
 	if a.Hex {
118
-		values.Set("secret", conf.Secret.Hex())
164
+		values.Set("secret", firstSecret.Hex())
119 165
 	} else {
120
-		values.Set("secret", conf.Secret.Base64())
166
+		values.Set("secret", firstSecret.Base64())
121 167
 	}
122 168
 
123 169
 	urlQuery := values.Encode()

+ 13
- 5
internal/cli/doctor.go Просмотреть файл

@@ -290,8 +290,16 @@ func (d *Doctor) checkNetworkAddresses(ntw mtglib.Network, addresses []string) e
290 290
 	return err
291 291
 }
292 292
 
293
+func (d *Doctor) getFirstSecretHost() string {
294
+	for _, s := range d.conf.GetSecrets() {
295
+		return s.Host
296
+	}
297
+
298
+	return ""
299
+}
300
+
293 301
 func (d *Doctor) checkFrontingDomain(ntw mtglib.Network) bool {
294
-	host := d.conf.Secret.Host
302
+	host := d.getFirstSecretHost()
295 303
 	if ip := d.conf.GetDomainFrontingIP(nil); ip != "" {
296 304
 		host = ip
297 305
 	}
@@ -323,10 +331,10 @@ func (d *Doctor) checkFrontingDomain(ntw mtglib.Network) bool {
323 331
 }
324 332
 
325 333
 func (d *Doctor) checkSecretHost(resolver *net.Resolver, ntw mtglib.Network) bool {
326
-	addresses, err := resolver.LookupIPAddr(context.Background(), d.conf.Secret.Host)
334
+	addresses, err := resolver.LookupIPAddr(context.Background(), d.getFirstSecretHost())
327 335
 	if err != nil {
328 336
 		tplError.Execute(os.Stdout, map[string]any{ //nolint: errcheck
329
-			"description": fmt.Sprintf("cannot resolve DNS name of %s", d.conf.Secret.Host),
337
+			"description": fmt.Sprintf("cannot resolve DNS name of %s", d.getFirstSecretHost()),
330 338
 			"error":       err,
331 339
 		})
332 340
 		return false
@@ -356,7 +364,7 @@ func (d *Doctor) checkSecretHost(resolver *net.Resolver, ntw mtglib.Network) boo
356 364
 			(ourIP6 != nil && value.IP.String() == ourIP6.String()) {
357 365
 			tplODNSSNIMatch.Execute(os.Stdout, map[string]any{ //nolint: errcheck
358 366
 				"ip":       value.IP,
359
-				"hostname": d.conf.Secret.Host,
367
+				"hostname": d.getFirstSecretHost(),
360 368
 			})
361 369
 			return true
362 370
 		}
@@ -365,7 +373,7 @@ func (d *Doctor) checkSecretHost(resolver *net.Resolver, ntw mtglib.Network) boo
365 373
 	}
366 374
 
367 375
 	tplEDNSSNIMatch.Execute(os.Stdout, map[string]any{ //nolint: errcheck
368
-		"hostname": d.conf.Secret.Host,
376
+		"hostname": d.getFirstSecretHost(),
369 377
 		"resolved": strings.Join(strAddresses, ", "),
370 378
 		"ip4":      ourIP4,
371 379
 		"ip6":      ourIP6,

+ 3
- 1
internal/cli/run_proxy.go Просмотреть файл

@@ -253,7 +253,7 @@ func runProxy(conf *config.Config, version string) error { //nolint: funlen
253 253
 		IPAllowlist:     allowlist,
254 254
 		EventStream:     eventStream,
255 255
 
256
-		Secret:                      conf.Secret,
256
+		Secrets:                     conf.GetSecrets(),
257 257
 		Concurrency:                 conf.GetConcurrency(mtglib.DefaultConcurrency),
258 258
 		DomainFrontingPort:          conf.GetDomainFrontingPort(mtglib.DefaultDomainFrontingPort),
259 259
 		DomainFrontingIP:            conf.GetDomainFrontingIP(nil),
@@ -269,6 +269,8 @@ func runProxy(conf *config.Config, version string) error { //nolint: funlen
269 269
 		DoppelGangerPerRaid: conf.Defense.Doppelganger.Repeats.Get(mtglib.DoppelGangerPerRaid),
270 270
 		DoppelGangerEach:    conf.Defense.Doppelganger.UpdateEach.Get(mtglib.DoppelGangerEach),
271 271
 		DoppelGangerDRS:     conf.Defense.Doppelganger.DRS.Get(false),
272
+
273
+		APIBindTo: conf.APIBindTo.Get(""),
272 274
 	}
273 275
 
274 276
 	proxy, err := mtglib.NewProxy(opts)

+ 27
- 7
internal/config/config.go Просмотреть файл

@@ -23,10 +23,11 @@ type ListConfig struct {
23 23
 }
24 24
 
25 25
 type Config struct {
26
-	Debug                       TypeBool        `json:"debug"`
27
-	AllowFallbackOnUnknownDC    TypeBool        `json:"allowFallbackOnUnknownDc"`
28
-	Secret                      mtglib.Secret   `json:"secret"`
29
-	BindTo                      TypeHostPort    `json:"bindTo"`
26
+	Debug                       TypeBool                   `json:"debug"`
27
+	AllowFallbackOnUnknownDC    TypeBool                   `json:"allowFallbackOnUnknownDc"`
28
+	Secret                      mtglib.Secret              `json:"secret"`
29
+	Secrets                     map[string]mtglib.Secret   `json:"secrets"`
30
+	BindTo                      TypeHostPort               `json:"bindTo"`
30 31
 	ProxyProtocolListener       TypeBool        `json:"proxyProtocolListener"`
31 32
 	PreferIP                    TypePreferIP    `json:"preferIp"`
32 33
 	AutoUpdate                  TypeBool        `json:"autoUpdate"`
@@ -68,7 +69,8 @@ type Config struct {
68 69
 		DNS     TypeDNSURI     `json:"dns"`
69 70
 		Proxies []TypeProxyURL `json:"proxies"`
70 71
 	} `json:"network"`
71
-	Stats struct {
72
+	APIBindTo TypeHostPort `json:"apiBindTo"`
73
+	Stats     struct {
72 74
 		StatsD struct {
73 75
 			Optional
74 76
 
@@ -125,8 +127,16 @@ func (c *Config) GetDomainFrontingProxyProtocol(defaultValue bool) bool {
125 127
 }
126 128
 
127 129
 func (c *Config) Validate() error {
128
-	if !c.Secret.Valid() {
129
-		return fmt.Errorf("invalid secret %s", c.Secret.String())
130
+	if len(c.Secrets) == 0 {
131
+		if !c.Secret.Valid() {
132
+			return fmt.Errorf("invalid secret %s", c.Secret.String())
133
+		}
134
+	} else {
135
+		for name, s := range c.Secrets {
136
+			if !s.Valid() {
137
+				return fmt.Errorf("invalid secret %q: %s", name, s.String())
138
+			}
139
+		}
130 140
 	}
131 141
 
132 142
 	if c.BindTo.Get("") == "" {
@@ -136,6 +146,16 @@ func (c *Config) Validate() error {
136 146
 	return nil
137 147
 }
138 148
 
149
+// GetSecrets returns all secrets as a map. If the new [secrets] section is used,
150
+// returns that map. Otherwise, wraps the single Secret as {"default": Secret}.
151
+func (c *Config) GetSecrets() map[string]mtglib.Secret {
152
+	if len(c.Secrets) > 0 {
153
+		return c.Secrets
154
+	}
155
+
156
+	return map[string]mtglib.Secret{"default": c.Secret}
157
+}
158
+
139 159
 func (c *Config) String() string {
140 160
 	buf := &bytes.Buffer{}
141 161
 	encoder := json.NewEncoder(buf)

+ 5
- 3
internal/config/parse.go Просмотреть файл

@@ -11,8 +11,9 @@ import (
11 11
 type tomlConfig struct {
12 12
 	Debug                       bool   `toml:"debug" json:"debug,omitempty"`
13 13
 	AllowFallbackOnUnknownDC    bool   `toml:"allow-fallback-on-unknown-dc" json:"allowFallbackOnUnknownDc,omitempty"`
14
-	Secret                      string `toml:"secret" json:"secret"`
15
-	BindTo                      string `toml:"bind-to" json:"bindTo"`
14
+	Secret                      string            `toml:"secret" json:"secret,omitempty"`
15
+	Secrets                     map[string]string `toml:"secrets" json:"secrets,omitempty"`
16
+	BindTo                      string            `toml:"bind-to" json:"bindTo"`
16 17
 	ProxyProtocolListener       bool   `toml:"proxy-protocol-listener" json:"proxyProtocolListener"`
17 18
 	PreferIP                    string `toml:"prefer-ip" json:"preferIp,omitempty"`
18 19
 	AutoUpdate                  bool   `toml:"auto-update" json:"autoUpdate,omitempty"`
@@ -63,7 +64,8 @@ type tomlConfig struct {
63 64
 		DNS     string   `toml:"dns" json:"dns,omitempty"`
64 65
 		Proxies []string `toml:"proxies" json:"proxies,omitempty"`
65 66
 	} `toml:"network" json:"network,omitempty"`
66
-	Stats struct {
67
+	APIBindTo string `toml:"api-bind-to" json:"apiBindTo,omitempty"`
68
+	Stats     struct {
67 69
 		StatsD struct {
68 70
 			Enabled      bool   `toml:"enabled" json:"enabled,omitempty"`
69 71
 			Address      string `toml:"address" json:"address,omitempty"`

+ 35
- 0
mtglib/counting_conn.go Просмотреть файл

@@ -0,0 +1,35 @@
1
+package mtglib
2
+
3
+import (
4
+	"github.com/9seconds/mtg/v2/essentials"
5
+)
6
+
7
+// countingConn wraps essentials.Conn and counts bytes through a cached
8
+// *secretStats pointer. The pointer is resolved once at construction time
9
+// so Read/Write never need to acquire a lock.
10
+type countingConn struct {
11
+	essentials.Conn
12
+	st *secretStats
13
+}
14
+
15
+func newCountingConn(conn essentials.Conn, stats *ProxyStats, secretName string) *countingConn {
16
+	return &countingConn{Conn: conn, st: stats.getOrCreate(secretName)}
17
+}
18
+
19
+func (c *countingConn) Read(p []byte) (int, error) {
20
+	n, err := c.Conn.Read(p)
21
+	if n > 0 {
22
+		c.st.bytesIn.Add(int64(n))
23
+	}
24
+
25
+	return n, err
26
+}
27
+
28
+func (c *countingConn) Write(p []byte) (int, error) {
29
+	n, err := c.Conn.Write(p)
30
+	if n > 0 {
31
+		c.st.bytesOut.Add(int64(n))
32
+	}
33
+
34
+	return n, err
35
+}

+ 52
- 27
mtglib/internal/tls/fake/client_side.go Просмотреть файл

@@ -36,26 +36,41 @@ type ClientHello struct {
36 36
 	CipherSuite uint16
37 37
 }
38 38
 
39
+// ReadClientHelloResult contains the parsed ClientHello and the index of the
40
+// secret that matched the HMAC validation.
41
+type ReadClientHelloResult struct {
42
+	Hello        *ClientHello
43
+	MatchedIndex int
44
+}
45
+
39 46
 func ReadClientHello(
40 47
 	conn net.Conn,
41 48
 	secret []byte,
42 49
 	hostname string,
43 50
 	tolerateTimeSkewness time.Duration,
44 51
 ) (*ClientHello, error) {
52
+	result, err := ReadClientHelloMulti(conn, [][]byte{secret}, hostname, tolerateTimeSkewness)
53
+	if err != nil {
54
+		return nil, err
55
+	}
56
+
57
+	return result.Hello, nil
58
+}
59
+
60
+// ReadClientHelloMulti is like ReadClientHello but accepts multiple secrets.
61
+// It tries each secret until one validates the HMAC. On success it returns
62
+// the ClientHello and the index of the matched secret.
63
+func ReadClientHelloMulti(
64
+	conn net.Conn,
65
+	secrets [][]byte,
66
+	hostname string,
67
+	tolerateTimeSkewness time.Duration,
68
+) (*ReadClientHelloResult, error) {
45 69
 	if err := conn.SetReadDeadline(time.Now().Add(ClientHelloReadTimeout)); err != nil {
46 70
 		return nil, fmt.Errorf("cannot set read deadline: %w", err)
47 71
 	}
48 72
 	defer conn.SetReadDeadline(resetDeadline) //nolint: errcheck
49 73
 
50
-	// This is how FakeTLS is organized:
51
-	//  1. We create sha256 HMAC with a given secret
52
-	//  2. We dump there a whole TLS frame except of the fact that random
53
-	//     is filled with all zeroes
54
-	//  3. Digest is computed. This digest should be XORed with
55
-	//     original client random
56
-	//  4. New digest should be all 0 except of last 4 bytes
57
-	//  5. Last 4 bytes are little endian uint32 of UNIX timestamp when
58
-	//     this message was created.
59 74
 	handshakeCopyBuf := &bytes.Buffer{}
60 75
 	reader := io.TeeReader(conn, handshakeCopyBuf)
61 76
 
@@ -83,31 +98,41 @@ func ReadClientHello(
83 98
 		return nil, fmt.Errorf("cannot find %s in %v", hostname, sniHostnames)
84 99
 	}
85 100
 
86
-	digest := hmac.New(sha256.New, secret)
87
-	// we write a copy of the handshake with client random all nullified.
88
-	digest.Write(handshakeCopyBuf.Next(RandomOffset))
89
-	handshakeCopyBuf.Next(RandomLen)
90
-	digest.Write(emptyRandom[:])
91
-	digest.Write(handshakeCopyBuf.Bytes())
101
+	// Save the handshake bytes so we can reuse them for each secret attempt.
102
+	handshakeBytes := handshakeCopyBuf.Bytes()
92 103
 
93
-	computed := digest.Sum(nil)
104
+	for idx, secret := range secrets {
105
+		digest := hmac.New(sha256.New, secret)
94 106
 
95
-	for i := range RandomLen {
96
-		computed[i] ^= hello.Random[i]
97
-	}
107
+		// Write the handshake with client random all nullified.
108
+		digest.Write(handshakeBytes[:RandomOffset])
109
+		digest.Write(emptyRandom[:])
110
+		digest.Write(handshakeBytes[RandomOffset+RandomLen:])
98 111
 
99
-	if subtle.ConstantTimeCompare(emptyRandom[:RandomLen-4], computed[:RandomLen-4]) != 1 {
100
-		return nil, ErrBadDigest
101
-	}
112
+		computed := digest.Sum(nil)
113
+
114
+		for i := range RandomLen {
115
+			computed[i] ^= hello.Random[i]
116
+		}
117
+
118
+		if subtle.ConstantTimeCompare(emptyRandom[:RandomLen-4], computed[:RandomLen-4]) != 1 {
119
+			continue
120
+		}
102 121
 
103
-	timestamp := int64(binary.LittleEndian.Uint32(computed[RandomLen-4:]))
104
-	createdAt := time.Unix(timestamp, 0)
122
+		timestamp := int64(binary.LittleEndian.Uint32(computed[RandomLen-4:]))
123
+		createdAt := time.Unix(timestamp, 0)
105 124
 
106
-	if tdiff := time.Since(createdAt).Abs(); tdiff > tolerateTimeSkewness {
107
-		return nil, fmt.Errorf("timestamp %q is too old %s", createdAt, tdiff)
125
+		if tdiff := time.Since(createdAt).Abs(); tdiff > tolerateTimeSkewness {
126
+			continue
127
+		}
128
+
129
+		return &ReadClientHelloResult{
130
+			Hello:        hello,
131
+			MatchedIndex: idx,
132
+		}, nil
108 133
 	}
109 134
 
110
-	return hello, nil
135
+	return nil, ErrBadDigest
111 136
 }
112 137
 
113 138
 func parseTLSHeader(r io.Reader) (io.Reader, error) {

+ 137
- 0
mtglib/internal/tls/fake/client_side_test.go Просмотреть файл

@@ -3,8 +3,11 @@ package fake_test
3 3
 import (
4 4
 	"bytes"
5 5
 	"encoding/binary"
6
+	"encoding/json"
6 7
 	"errors"
7 8
 	"io"
9
+	"os"
10
+	"path/filepath"
8 11
 	"testing"
9 12
 	"time"
10 13
 
@@ -393,3 +396,137 @@ func TestParseClientHelloSNI(t *testing.T) {
393 396
 	t.Parallel()
394 397
 	suite.Run(t, &ParseClientHelloSNITestSuite{})
395 398
 }
399
+
400
+// --- ReadClientHelloMulti tests ---
401
+
402
+type ReadClientHelloMultiTestSuite struct {
403
+	suite.Suite
404
+
405
+	secret mtglib.Secret
406
+}
407
+
408
+func (suite *ReadClientHelloMultiTestSuite) SetupSuite() {
409
+	parsed, err := mtglib.ParseSecret(
410
+		"ee367a189aee18fa31c190054efd4a8e9573746f726167652e676f6f676c65617069732e636f6d",
411
+	)
412
+	require.NoError(suite.T(), err)
413
+
414
+	suite.secret = parsed
415
+}
416
+
417
+func (suite *ReadClientHelloMultiTestSuite) loadSnapshot(name string) []byte {
418
+	data, err := os.ReadFile(filepath.Join("testdata", name))
419
+	require.NoError(suite.T(), err)
420
+
421
+	snapshot := &clientHelloSnapshot{}
422
+	require.NoError(suite.T(), json.Unmarshal(data, snapshot))
423
+
424
+	return snapshot.GetFull()
425
+}
426
+
427
+func (suite *ReadClientHelloMultiTestSuite) makeConn(data []byte) *parseClientHelloConnMock {
428
+	readBuf := &bytes.Buffer{}
429
+	readBuf.Write(data)
430
+
431
+	connMock := &parseClientHelloConnMock{
432
+		readBuf: readBuf,
433
+	}
434
+
435
+	connMock.
436
+		On("SetReadDeadline", mock.AnythingOfType("time.Time")).
437
+		Twice().
438
+		Return(nil)
439
+
440
+	return connMock
441
+}
442
+
443
+func (suite *ReadClientHelloMultiTestSuite) TestMatchesCorrectSecretAtIndex0() {
444
+	payload := suite.loadSnapshot("client-hello-ok-19dfe38384b9884b.json")
445
+	connMock := suite.makeConn(payload)
446
+	defer connMock.AssertExpectations(suite.T())
447
+
448
+	wrongSecret := mtglib.GenerateSecret("storage.googleapis.com")
449
+
450
+	result, err := fake.ReadClientHelloMulti(
451
+		connMock,
452
+		[][]byte{suite.secret.Key[:], wrongSecret.Key[:]},
453
+		suite.secret.Host,
454
+		TolerateTime,
455
+	)
456
+	suite.NoError(err)
457
+	suite.Equal(0, result.MatchedIndex)
458
+	suite.NotNil(result.Hello)
459
+}
460
+
461
+func (suite *ReadClientHelloMultiTestSuite) TestMatchesCorrectSecretAtIndex1() {
462
+	payload := suite.loadSnapshot("client-hello-ok-19dfe38384b9884b.json")
463
+	connMock := suite.makeConn(payload)
464
+	defer connMock.AssertExpectations(suite.T())
465
+
466
+	wrongSecret := mtglib.GenerateSecret("storage.googleapis.com")
467
+
468
+	result, err := fake.ReadClientHelloMulti(
469
+		connMock,
470
+		[][]byte{wrongSecret.Key[:], suite.secret.Key[:]},
471
+		suite.secret.Host,
472
+		TolerateTime,
473
+	)
474
+	suite.NoError(err)
475
+	suite.Equal(1, result.MatchedIndex)
476
+	suite.NotNil(result.Hello)
477
+}
478
+
479
+func (suite *ReadClientHelloMultiTestSuite) TestMatchesCorrectSecretAtIndex2() {
480
+	payload := suite.loadSnapshot("client-hello-ok-19dfe38384b9884b.json")
481
+	connMock := suite.makeConn(payload)
482
+	defer connMock.AssertExpectations(suite.T())
483
+
484
+	wrong1 := mtglib.GenerateSecret("storage.googleapis.com")
485
+	wrong2 := mtglib.GenerateSecret("storage.googleapis.com")
486
+
487
+	result, err := fake.ReadClientHelloMulti(
488
+		connMock,
489
+		[][]byte{wrong1.Key[:], wrong2.Key[:], suite.secret.Key[:]},
490
+		suite.secret.Host,
491
+		TolerateTime,
492
+	)
493
+	suite.NoError(err)
494
+	suite.Equal(2, result.MatchedIndex)
495
+	suite.NotNil(result.Hello)
496
+}
497
+
498
+func (suite *ReadClientHelloMultiTestSuite) TestNoMatchReturnsBadDigest() {
499
+	payload := suite.loadSnapshot("client-hello-ok-19dfe38384b9884b.json")
500
+	connMock := suite.makeConn(payload)
501
+	defer connMock.AssertExpectations(suite.T())
502
+
503
+	wrong1 := mtglib.GenerateSecret("storage.googleapis.com")
504
+	wrong2 := mtglib.GenerateSecret("storage.googleapis.com")
505
+
506
+	_, err := fake.ReadClientHelloMulti(
507
+		connMock,
508
+		[][]byte{wrong1.Key[:], wrong2.Key[:]},
509
+		suite.secret.Host,
510
+		TolerateTime,
511
+	)
512
+	suite.ErrorIs(err, fake.ErrBadDigest)
513
+}
514
+
515
+func (suite *ReadClientHelloMultiTestSuite) TestBadSnapshotReturnsBadDigest() {
516
+	payload := suite.loadSnapshot("client-hello-bad-fa2e46cdb33e2a1b.json")
517
+	connMock := suite.makeConn(payload)
518
+	defer connMock.AssertExpectations(suite.T())
519
+
520
+	_, err := fake.ReadClientHelloMulti(
521
+		connMock,
522
+		[][]byte{suite.secret.Key[:]},
523
+		suite.secret.Host,
524
+		TolerateTime,
525
+	)
526
+	suite.ErrorIs(err, fake.ErrBadDigest)
527
+}
528
+
529
+func TestReadClientHelloMulti(t *testing.T) {
530
+	t.Parallel()
531
+	suite.Run(t, &ReadClientHelloMultiTestSuite{})
532
+}

+ 64
- 14
mtglib/proxy.go Просмотреть файл

@@ -5,6 +5,7 @@ import (
5 5
 	"errors"
6 6
 	"fmt"
7 7
 	"net"
8
+	"sort"
8 9
 	"strconv"
9 10
 	"sync"
10 11
 	"time"
@@ -35,9 +36,10 @@ type Proxy struct {
35 36
 	telegram                    *dc.Telegram
36 37
 	configUpdater               *dc.PublicConfigUpdater
37 38
 	doppelGanger                *doppel.Ganger
38
-	clientObfuscatror           obfuscation.Obfuscator
39 39
 
40
-	secret          Secret
40
+	stats       *ProxyStats
41
+	secrets     []Secret
42
+	secretNames []string
41 43
 	network         Network
42 44
 	antiReplayCache AntiReplayCache
43 45
 	blocklist       IPBlocklist
@@ -49,7 +51,9 @@ type Proxy struct {
49 51
 // DomainFrontingAddress returns a host:port pair for a fronting domain.
50 52
 // If DomainFrontingIP is set, it is used instead of resolving the hostname.
51 53
 func (p *Proxy) DomainFrontingAddress() string {
52
-	host := p.secret.Host
54
+	// All secrets share the same host (enforced by validation),
55
+	// so we use the first one.
56
+	host := p.secrets[0].Host
53 57
 	if p.domainFrontingIP != "" {
54 58
 		host = p.domainFrontingIP
55 59
 	}
@@ -83,6 +87,11 @@ func (p *Proxy) ServeConn(conn essentials.Conn) {
83 87
 		return
84 88
 	}
85 89
 
90
+	p.stats.OnConnect(ctx.secretName)
91
+	p.stats.UpdateLastSeen(ctx.secretName)
92
+
93
+	defer p.stats.OnDisconnect(ctx.secretName)
94
+
86 95
 	clientConn, err := p.doppelGanger.NewConn(ctx.clientConn)
87 96
 	if err != nil {
88 97
 		ctx.logger.InfoError("cannot wrap into doppelganger connection", err)
@@ -102,11 +111,13 @@ func (p *Proxy) ServeConn(conn essentials.Conn) {
102 111
 		return
103 112
 	}
104 113
 
114
+	countedClientConn := newCountingConn(ctx.clientConn, p.stats, ctx.secretName)
115
+
105 116
 	relay.Relay(
106 117
 		ctx,
107 118
 		ctx.logger.Named("relay"),
108 119
 		ctx.telegramConn,
109
-		ctx.clientConn,
120
+		countedClientConn,
110 121
 	)
111 122
 }
112 123
 
@@ -175,10 +186,16 @@ func (p *Proxy) Shutdown() {
175 186
 func (p *Proxy) doFakeTLSHandshake(ctx *streamContext) bool {
176 187
 	rewind := newConnRewind(ctx.clientConn)
177 188
 
178
-	clientHello, err := fake.ReadClientHello(
189
+	// Build a slice of secret keys to try during HMAC validation.
190
+	secretKeys := make([][]byte, len(p.secrets))
191
+	for i := range p.secrets {
192
+		secretKeys[i] = p.secrets[i].Key[:]
193
+	}
194
+
195
+	result, err := fake.ReadClientHelloMulti(
179 196
 		rewind,
180
-		p.secret.Key[:],
181
-		p.secret.Host,
197
+		secretKeys,
198
+		p.secrets[0].Host,
182 199
 		p.tolerateTimeSkewness,
183 200
 	)
184 201
 	if err != nil {
@@ -187,17 +204,22 @@ func (p *Proxy) doFakeTLSHandshake(ctx *streamContext) bool {
187 204
 		return false
188 205
 	}
189 206
 
190
-	if p.antiReplayCache.SeenBefore(clientHello.SessionID) {
207
+	if p.antiReplayCache.SeenBefore(result.Hello.SessionID) {
191 208
 		p.logger.Warning("replay attack has been detected!")
192 209
 		p.eventStream.Send(p.ctx, NewEventReplayAttack(ctx.streamID))
193 210
 		p.doDomainFronting(ctx, rewind)
194 211
 		return false
195 212
 	}
196 213
 
214
+	matchedSecret := p.secrets[result.MatchedIndex]
215
+	ctx.matchedSecretKey = matchedSecret.Key[:]
216
+	ctx.secretName = p.secretNames[result.MatchedIndex]
217
+	ctx.logger = ctx.logger.BindStr("secret_name", ctx.secretName)
218
+
197 219
 	gangerNoise := p.doppelGanger.NoiseParams()
198 220
 	noiseParams := fake.NoiseParams{Mean: gangerNoise.Mean, Jitter: gangerNoise.Jitter}
199 221
 
200
-	if err := fake.SendServerHello(ctx.clientConn, p.secret.Key[:], clientHello, noiseParams); err != nil {
222
+	if err := fake.SendServerHello(ctx.clientConn, matchedSecret.Key[:], result.Hello, noiseParams); err != nil {
201 223
 		p.logger.InfoError("cannot send welcome packet", err)
202 224
 		return false
203 225
 	}
@@ -208,7 +230,12 @@ func (p *Proxy) doFakeTLSHandshake(ctx *streamContext) bool {
208 230
 }
209 231
 
210 232
 func (p *Proxy) doObfuscatedHandshake(ctx *streamContext) error {
211
-	dc, conn, err := p.clientObfuscatror.ReadHandshake(ctx.clientConn)
233
+	// Use the secret key that was matched during the FakeTLS handshake.
234
+	obfs := obfuscation.Obfuscator{
235
+		Secret: ctx.matchedSecretKey,
236
+	}
237
+
238
+	dc, conn, err := obfs.ReadHandshake(ctx.clientConn)
212 239
 	if err != nil {
213 240
 		return fmt.Errorf("cannot process client handshake: %w", err)
214 241
 	}
@@ -328,10 +355,36 @@ func NewProxy(opts ProxyOpts) (*Proxy, error) {
328 355
 	logger := opts.getLogger("proxy")
329 356
 	updatersLogger := logger.Named("telegram-updaters")
330 357
 
358
+	secretsMap := opts.getSecrets()
359
+	secretNames := make([]string, 0, len(secretsMap))
360
+
361
+	for name := range secretsMap {
362
+		secretNames = append(secretNames, name)
363
+	}
364
+
365
+	sort.Strings(secretNames)
366
+
367
+	secretsList := make([]Secret, 0, len(secretsMap))
368
+
369
+	for _, name := range secretNames {
370
+		secretsList = append(secretsList, secretsMap[name])
371
+	}
372
+
373
+	stats := NewProxyStats()
374
+	for _, name := range secretNames {
375
+		stats.PreRegister(name)
376
+	}
377
+
378
+	if opts.APIBindTo != "" {
379
+		stats.StartServer(ctx, opts.APIBindTo, logger)
380
+	}
381
+
331 382
 	proxy := &Proxy{
332 383
 		ctx:                      ctx,
333 384
 		ctxCancel:                cancel,
334
-		secret:                   opts.Secret,
385
+		stats:                    stats,
386
+		secrets:                  secretsList,
387
+		secretNames:              secretNames,
335 388
 		network:                  opts.Network,
336 389
 		antiReplayCache:          opts.AntiReplayCache,
337 390
 		blocklist:                opts.IPBlocklist,
@@ -358,9 +411,6 @@ func NewProxy(opts ProxyOpts) (*Proxy, error) {
358 411
 			updatersLogger.Named("public-config"),
359 412
 			opts.Network.MakeHTTPClient(nil),
360 413
 		),
361
-		clientObfuscatror: obfuscation.Obfuscator{
362
-			Secret: opts.Secret.Key[:],
363
-		},
364 414
 		domainFrontingProxyProtocol: opts.DomainFrontingProxyProtocol,
365 415
 	}
366 416
 

+ 48
- 3
mtglib/proxy_opts.go Просмотреть файл

@@ -1,6 +1,9 @@
1 1
 package mtglib
2 2
 
3
-import "time"
3
+import (
4
+	"fmt"
5
+	"time"
6
+)
4 7
 
5 8
 // ProxyOpts is a structure with settings to mtg proxy.
6 9
 //
@@ -9,9 +12,15 @@ import "time"
9 12
 type ProxyOpts struct {
10 13
 	// Secret defines a secret which should be used by a proxy.
11 14
 	//
12
-	// This is a mandatory setting.
15
+	// Deprecated: Use Secrets instead for multi-secret support.
16
+	// Kept for backward compatibility.
13 17
 	Secret Secret
14 18
 
19
+	// Secrets defines a map of named secrets which should be used by a proxy.
20
+	// If set, Secret is ignored. During FakeTLS handshake, each secret is
21
+	// tried until one validates.
22
+	Secrets map[string]Secret
23
+
15 24
 	// Network defines a network instance which should be used for all network
16 25
 	// communications made by proxies.
17 26
 	//
@@ -161,6 +170,11 @@ type ProxyOpts struct {
161 170
 	// DoppelGangerDRS defines if TLS Dynamic Record Sizing is active.
162 171
 	DoppelGangerDRS bool
163 172
 
173
+	// APIBindTo is the address to bind the stats HTTP API server to.
174
+	// If empty, the stats API server is not started.
175
+	//
176
+	// This is an optional setting.
177
+	APIBindTo string
164 178
 }
165 179
 
166 180
 func (p ProxyOpts) valid() error {
@@ -177,10 +191,41 @@ func (p ProxyOpts) valid() error {
177 191
 		return ErrEventStreamIsNotDefined
178 192
 	case p.Logger == nil:
179 193
 		return ErrLoggerIsNotDefined
180
-	case !p.Secret.Valid():
194
+	}
195
+
196
+	secrets := p.getSecrets()
197
+	if len(secrets) == 0 {
181 198
 		return ErrSecretInvalid
182 199
 	}
183 200
 
201
+	var host string
202
+
203
+	for _, s := range secrets {
204
+		if !s.Valid() {
205
+			return ErrSecretInvalid
206
+		}
207
+
208
+		if host == "" {
209
+			host = s.Host
210
+		} else if s.Host != host {
211
+			return fmt.Errorf("all secrets must use the same hostname, got %q and %q", host, s.Host)
212
+		}
213
+	}
214
+
215
+	return nil
216
+}
217
+
218
+// getSecrets returns the effective secrets map. If Secrets is populated, it is
219
+// returned directly. Otherwise the single Secret is wrapped in a map.
220
+func (p ProxyOpts) getSecrets() map[string]Secret {
221
+	if len(p.Secrets) > 0 {
222
+		return p.Secrets
223
+	}
224
+
225
+	if p.Secret.Valid() {
226
+		return map[string]Secret{"default": p.Secret}
227
+	}
228
+
184 229
 	return nil
185 230
 }
186 231
 

+ 179
- 0
mtglib/proxy_stats.go Просмотреть файл

@@ -0,0 +1,179 @@
1
+package mtglib
2
+
3
+import (
4
+	"context"
5
+	"encoding/json"
6
+	"net"
7
+	"net/http"
8
+	"sync"
9
+	"sync/atomic"
10
+	"time"
11
+)
12
+
13
+type secretStats struct {
14
+	connections atomic.Int64
15
+	bytesIn     atomic.Int64
16
+	bytesOut    atomic.Int64
17
+	lastSeen    atomic.Value // stores time.Time
18
+}
19
+
20
+// ProxyStats tracks per-secret connection stats with atomic counters.
21
+// Thread-safe for concurrent access from proxy goroutines.
22
+type ProxyStats struct {
23
+	mu        sync.RWMutex
24
+	users     map[string]*secretStats
25
+	startedAt time.Time
26
+}
27
+
28
+// NewProxyStats creates a new ProxyStats instance.
29
+func NewProxyStats() *ProxyStats {
30
+	return &ProxyStats{
31
+		users:     make(map[string]*secretStats),
32
+		startedAt: time.Now(),
33
+	}
34
+}
35
+
36
+func (s *ProxyStats) getOrCreate(name string) *secretStats {
37
+	s.mu.RLock()
38
+	st, ok := s.users[name]
39
+	s.mu.RUnlock()
40
+
41
+	if ok {
42
+		return st
43
+	}
44
+
45
+	s.mu.Lock()
46
+	defer s.mu.Unlock()
47
+
48
+	if st, ok = s.users[name]; ok {
49
+		return st
50
+	}
51
+
52
+	st = &secretStats{}
53
+	st.lastSeen.Store(time.Time{})
54
+	s.users[name] = st
55
+
56
+	return st
57
+}
58
+
59
+// PreRegister adds a secret name to the stats map so it appears in output
60
+// even if no connections have been made yet.
61
+func (s *ProxyStats) PreRegister(name string) {
62
+	s.getOrCreate(name)
63
+}
64
+
65
+// OnConnect increments the active connection count for the given secret.
66
+func (s *ProxyStats) OnConnect(name string) {
67
+	s.getOrCreate(name).connections.Add(1)
68
+}
69
+
70
+// OnDisconnect decrements the active connection count for the given secret.
71
+func (s *ProxyStats) OnDisconnect(name string) {
72
+	s.getOrCreate(name).connections.Add(-1)
73
+}
74
+
75
+// AddBytesIn adds to the bytes-in counter for the given secret.
76
+func (s *ProxyStats) AddBytesIn(name string, n int64) {
77
+	s.getOrCreate(name).bytesIn.Add(n)
78
+}
79
+
80
+// AddBytesOut adds to the bytes-out counter for the given secret.
81
+func (s *ProxyStats) AddBytesOut(name string, n int64) {
82
+	s.getOrCreate(name).bytesOut.Add(n)
83
+}
84
+
85
+// UpdateLastSeen sets the last-seen timestamp for the given secret to now.
86
+func (s *ProxyStats) UpdateLastSeen(name string) {
87
+	s.getOrCreate(name).lastSeen.Store(time.Now())
88
+}
89
+
90
+// StatsResponse is the JSON response for the stats endpoint.
91
+type StatsResponse struct {
92
+	StartedAt        time.Time                `json:"started_at"`
93
+	UptimeSeconds    int64                    `json:"uptime_seconds"`
94
+	TotalConnections int64                    `json:"total_connections"`
95
+	Users            map[string]UserStatsJSON `json:"users"`
96
+}
97
+
98
+// UserStatsJSON is the per-user portion of the stats JSON response.
99
+type UserStatsJSON struct {
100
+	Connections int64      `json:"connections"`
101
+	BytesIn     int64      `json:"bytes_in"`
102
+	BytesOut    int64      `json:"bytes_out"`
103
+	LastSeen    *time.Time `json:"last_seen"`
104
+}
105
+
106
+func (s *ProxyStats) ServeHTTP(w http.ResponseWriter, r *http.Request) {
107
+	s.mu.RLock()
108
+	defer s.mu.RUnlock()
109
+
110
+	var totalConns int64
111
+
112
+	users := make(map[string]UserStatsJSON, len(s.users))
113
+
114
+	for name, st := range s.users {
115
+		conns := st.connections.Load()
116
+		totalConns += conns
117
+
118
+		lastSeen := st.lastSeen.Load().(time.Time) //nolint: forcetypeassert
119
+		var lastSeenPtr *time.Time
120
+		if !lastSeen.IsZero() {
121
+			lastSeenPtr = &lastSeen
122
+		}
123
+
124
+		users[name] = UserStatsJSON{
125
+			Connections: conns,
126
+			BytesIn:     st.bytesIn.Load(),
127
+			BytesOut:    st.bytesOut.Load(),
128
+			LastSeen:    lastSeenPtr,
129
+		}
130
+	}
131
+
132
+	resp := StatsResponse{
133
+		StartedAt:        s.startedAt,
134
+		UptimeSeconds:    int64(time.Since(s.startedAt).Seconds()),
135
+		TotalConnections: totalConns,
136
+		Users:            users,
137
+	}
138
+
139
+	w.Header().Set("Content-Type", "application/json")
140
+
141
+	if err := json.NewEncoder(w).Encode(resp); err != nil {
142
+		http.Error(w, err.Error(), http.StatusInternalServerError)
143
+	}
144
+}
145
+
146
+// StartServer starts an HTTP server for the stats API in a background goroutine.
147
+// The server is shut down when ctx is cancelled.
148
+func (s *ProxyStats) StartServer(ctx context.Context, bindTo string, logger Logger) {
149
+	mux := http.NewServeMux()
150
+	mux.Handle("/stats", s)
151
+
152
+	srv := &http.Server{
153
+		Addr:    bindTo,
154
+		Handler: mux,
155
+	}
156
+
157
+	ln, err := net.Listen("tcp", bindTo)
158
+	if err != nil {
159
+		logger.WarningError("cannot start stats API listener", err)
160
+		return
161
+	}
162
+
163
+	go func() {
164
+		if err := srv.Serve(ln); err != nil && err != http.ErrServerClosed {
165
+			logger.WarningError("stats API server error", err)
166
+		}
167
+	}()
168
+
169
+	go func() {
170
+		<-ctx.Done()
171
+
172
+		shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) //nolint: mnd
173
+		defer cancel()
174
+
175
+		srv.Shutdown(shutdownCtx) //nolint: errcheck
176
+	}()
177
+
178
+	logger.BindStr("bind", bindTo).Info("Stats API server started")
179
+}

+ 183
- 0
mtglib/proxy_stats_test.go Просмотреть файл

@@ -0,0 +1,183 @@
1
+package mtglib
2
+
3
+import (
4
+	"encoding/json"
5
+	"net/http"
6
+	"net/http/httptest"
7
+	"testing"
8
+	"time"
9
+
10
+	"github.com/stretchr/testify/assert"
11
+	"github.com/stretchr/testify/require"
12
+)
13
+
14
+func TestNewProxyStats(t *testing.T) {
15
+	t.Parallel()
16
+
17
+	stats := NewProxyStats()
18
+	assert.NotNil(t, stats)
19
+	assert.NotNil(t, stats.users)
20
+	assert.False(t, stats.startedAt.IsZero())
21
+}
22
+
23
+func TestPreRegister(t *testing.T) {
24
+	t.Parallel()
25
+
26
+	stats := NewProxyStats()
27
+	stats.PreRegister("alice")
28
+	stats.PreRegister("bob")
29
+
30
+	stats.mu.RLock()
31
+	defer stats.mu.RUnlock()
32
+
33
+	assert.Contains(t, stats.users, "alice")
34
+	assert.Contains(t, stats.users, "bob")
35
+	assert.Equal(t, int64(0), stats.users["alice"].connections.Load())
36
+}
37
+
38
+func TestOnConnectDisconnect(t *testing.T) {
39
+	t.Parallel()
40
+
41
+	stats := NewProxyStats()
42
+	stats.PreRegister("alice")
43
+
44
+	stats.OnConnect("alice")
45
+	assert.Equal(t, int64(1), stats.users["alice"].connections.Load())
46
+
47
+	stats.OnConnect("alice")
48
+	assert.Equal(t, int64(2), stats.users["alice"].connections.Load())
49
+
50
+	stats.OnDisconnect("alice")
51
+	assert.Equal(t, int64(1), stats.users["alice"].connections.Load())
52
+
53
+	stats.OnDisconnect("alice")
54
+	assert.Equal(t, int64(0), stats.users["alice"].connections.Load())
55
+}
56
+
57
+func TestAddBytes(t *testing.T) {
58
+	t.Parallel()
59
+
60
+	stats := NewProxyStats()
61
+	stats.PreRegister("alice")
62
+
63
+	stats.AddBytesIn("alice", 100)
64
+	stats.AddBytesIn("alice", 200)
65
+	stats.AddBytesOut("alice", 50)
66
+
67
+	st := stats.users["alice"]
68
+	assert.Equal(t, int64(300), st.bytesIn.Load())
69
+	assert.Equal(t, int64(50), st.bytesOut.Load())
70
+}
71
+
72
+func TestUpdateLastSeen(t *testing.T) {
73
+	t.Parallel()
74
+
75
+	stats := NewProxyStats()
76
+	stats.PreRegister("alice")
77
+
78
+	before := time.Now()
79
+	stats.UpdateLastSeen("alice")
80
+	after := time.Now()
81
+
82
+	lastSeen := stats.users["alice"].lastSeen.Load().(time.Time)
83
+	assert.False(t, lastSeen.Before(before))
84
+	assert.False(t, lastSeen.After(after))
85
+}
86
+
87
+func TestGetOrCreateLazy(t *testing.T) {
88
+	t.Parallel()
89
+
90
+	stats := NewProxyStats()
91
+
92
+	// getOrCreate should create a new entry on first access.
93
+	stats.OnConnect("new-user")
94
+	assert.Equal(t, int64(1), stats.users["new-user"].connections.Load())
95
+}
96
+
97
+func TestServeHTTPBasic(t *testing.T) {
98
+	t.Parallel()
99
+
100
+	stats := NewProxyStats()
101
+	stats.PreRegister("alice")
102
+	stats.PreRegister("bob")
103
+
104
+	stats.OnConnect("alice")
105
+	stats.OnConnect("alice")
106
+	stats.OnConnect("bob")
107
+	stats.AddBytesIn("alice", 1024)
108
+	stats.AddBytesOut("alice", 512)
109
+	stats.UpdateLastSeen("alice")
110
+
111
+	rec := httptest.NewRecorder()
112
+	req := httptest.NewRequest(http.MethodGet, "/stats", nil)
113
+
114
+	stats.ServeHTTP(rec, req)
115
+
116
+	assert.Equal(t, http.StatusOK, rec.Code)
117
+	assert.Equal(t, "application/json", rec.Header().Get("Content-Type"))
118
+
119
+	var resp StatsResponse
120
+	err := json.Unmarshal(rec.Body.Bytes(), &resp)
121
+	require.NoError(t, err)
122
+
123
+	assert.Equal(t, int64(3), resp.TotalConnections)
124
+	assert.False(t, resp.StartedAt.IsZero())
125
+	assert.GreaterOrEqual(t, resp.UptimeSeconds, int64(0))
126
+
127
+	alice, ok := resp.Users["alice"]
128
+	require.True(t, ok)
129
+	assert.Equal(t, int64(2), alice.Connections)
130
+	assert.Equal(t, int64(1024), alice.BytesIn)
131
+	assert.Equal(t, int64(512), alice.BytesOut)
132
+	assert.NotNil(t, alice.LastSeen)
133
+
134
+	bob, ok := resp.Users["bob"]
135
+	require.True(t, ok)
136
+	assert.Equal(t, int64(1), bob.Connections)
137
+	assert.Equal(t, int64(0), bob.BytesIn)
138
+	assert.Equal(t, int64(0), bob.BytesOut)
139
+	assert.Nil(t, bob.LastSeen)
140
+}
141
+
142
+func TestServeHTTPEmpty(t *testing.T) {
143
+	t.Parallel()
144
+
145
+	stats := NewProxyStats()
146
+
147
+	rec := httptest.NewRecorder()
148
+	req := httptest.NewRequest(http.MethodGet, "/stats", nil)
149
+
150
+	stats.ServeHTTP(rec, req)
151
+
152
+	assert.Equal(t, http.StatusOK, rec.Code)
153
+
154
+	var resp StatsResponse
155
+	err := json.Unmarshal(rec.Body.Bytes(), &resp)
156
+	require.NoError(t, err)
157
+
158
+	assert.Empty(t, resp.Users)
159
+	assert.Equal(t, int64(0), resp.TotalConnections)
160
+}
161
+
162
+func TestServeHTTPLastSeenZeroIsNull(t *testing.T) {
163
+	t.Parallel()
164
+
165
+	stats := NewProxyStats()
166
+	stats.PreRegister("alice")
167
+
168
+	rec := httptest.NewRecorder()
169
+	req := httptest.NewRequest(http.MethodGet, "/stats", nil)
170
+
171
+	stats.ServeHTTP(rec, req)
172
+
173
+	var raw map[string]json.RawMessage
174
+	require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &raw))
175
+
176
+	var users map[string]json.RawMessage
177
+	require.NoError(t, json.Unmarshal(raw["users"], &users))
178
+
179
+	var aliceRaw map[string]json.RawMessage
180
+	require.NoError(t, json.Unmarshal(users["alice"], &aliceRaw))
181
+
182
+	assert.Equal(t, "null", string(aliceRaw["last_seen"]))
183
+}

+ 9
- 7
mtglib/stream_context.go Просмотреть файл

@@ -11,13 +11,15 @@ import (
11 11
 )
12 12
 
13 13
 type streamContext struct {
14
-	ctx          context.Context
15
-	ctxCancel    context.CancelFunc
16
-	clientConn   essentials.Conn
17
-	telegramConn essentials.Conn
18
-	streamID     string
19
-	dc           int
20
-	logger       Logger
14
+	ctx              context.Context
15
+	ctxCancel        context.CancelFunc
16
+	clientConn       essentials.Conn
17
+	telegramConn     essentials.Conn
18
+	streamID         string
19
+	dc               int
20
+	matchedSecretKey []byte
21
+	secretName       string
22
+	logger           Logger
21 23
 }
22 24
 
23 25
 func (s *streamContext) Deadline() (time.Time, bool) {

Загрузка…
Отмена
Сохранить