Przeglądaj źródła

Add multi-secret support and per-user stats API

Support multiple named secrets in [secrets] config section. During
FakeTLS handshake each secret is tried until HMAC validates. Matched
secret name is logged and used for per-user statistics tracking.

Built-in HTTP stats API (configurable via api-bind-to) exposes
GET /stats with per-user connection counts, bytes in/out, and
last-seen timestamps.

Backward compatible: single "secret" config key still works.
pull/434/head
Alexey Dolotov 1 miesiąc temu
rodzic
commit
1a450e3c45

+ 12
- 0
example.config.toml Wyświetl plik

20
 # should either be base64-encoded or starts with ee.
20
 # should either be base64-encoded or starts with ee.
21
 secret = "ee367a189aee18fa31c190054efd4a8e9573746f726167652e676f6f676c65617069732e636f6d"
21
 secret = "ee367a189aee18fa31c190054efd4a8e9573746f726167652e676f6f676c65617069732e636f6d"
22
 
22
 
23
+# For multi-user support, use the [secrets] section instead of "secret".
24
+# Each key is a user name, used for per-user stats tracking.
25
+# All secrets must use the same hostname.
26
+# [secrets]
27
+# alice = "ee367a189aee18fa31c190054efd4a8e9573746f726167652e676f6f676c65617069732e636f6d"
28
+# bob = "ee0123456789abcdef0123456789abcd9573746f726167652e676f6f676c65617069732e636f6d"
29
+
30
+# Host:port pair to bind the built-in stats HTTP API server.
31
+# GET /stats returns per-user connection counts and traffic.
32
+# If not set, the stats API is not started.
33
+# api-bind-to = "127.0.0.1:9090"
34
+
23
 # Host:port pair to run proxy on.
35
 # Host:port pair to run proxy on.
24
 bind-to = "0.0.0.0:3128"
36
 bind-to = "0.0.0.0:3128"
25
 
37
 

+ 56
- 10
internal/cli/access.go Wyświetl plik

6
 	"net"
6
 	"net"
7
 	"net/url"
7
 	"net/url"
8
 	"os"
8
 	"os"
9
+	"sort"
9
 	"strconv"
10
 	"strconv"
10
 	"sync"
11
 	"sync"
11
 
12
 
12
 	"github.com/9seconds/mtg/v2/internal/config"
13
 	"github.com/9seconds/mtg/v2/internal/config"
13
 	"github.com/9seconds/mtg/v2/internal/utils"
14
 	"github.com/9seconds/mtg/v2/internal/utils"
15
+	"github.com/9seconds/mtg/v2/mtglib"
14
 )
16
 )
15
 
17
 
18
+type accessResponseSecret struct {
19
+	Hex    string `json:"hex"`
20
+	Base64 string `json:"base64"`
21
+}
22
+
16
 type accessResponse struct {
23
 type accessResponse struct {
17
-	IPv4   *accessResponseURLs `json:"ipv4,omitempty"`
18
-	IPv6   *accessResponseURLs `json:"ipv6,omitempty"`
19
-	Secret struct {
20
-		Hex    string `json:"hex"`
21
-		Base64 string `json:"base64"`
22
-	} `json:"secret"`
24
+	IPv4    *accessResponseURLs             `json:"ipv4,omitempty"`
25
+	IPv6    *accessResponseURLs             `json:"ipv6,omitempty"`
26
+	Secret  accessResponseSecret            `json:"secret"`
27
+	Secrets map[string]accessResponseSecret `json:"secrets,omitempty"`
23
 }
28
 }
24
 
29
 
25
 type accessResponseURLs struct {
30
 type accessResponseURLs struct {
46
 	}
51
 	}
47
 
52
 
48
 	resp := &accessResponse{}
53
 	resp := &accessResponse{}
49
-	resp.Secret.Base64 = conf.Secret.Base64()
50
-	resp.Secret.Hex = conf.Secret.Hex()
54
+	secrets := conf.GetSecrets()
55
+
56
+	// Sort secret names for deterministic "first secret" selection.
57
+	sortedNames := make([]string, 0, len(secrets))
58
+	for name := range secrets {
59
+		sortedNames = append(sortedNames, name)
60
+	}
61
+
62
+	sort.Strings(sortedNames)
63
+
64
+	// For backward compatibility, populate the single Secret field with the
65
+	// first secret (sorted alphabetically).
66
+	if len(sortedNames) > 0 {
67
+		first := secrets[sortedNames[0]]
68
+		resp.Secret.Base64 = first.Base64()
69
+		resp.Secret.Hex = first.Hex()
70
+	}
71
+
72
+	if len(secrets) > 1 {
73
+		resp.Secrets = make(map[string]accessResponseSecret, len(secrets))
74
+
75
+		for name, s := range secrets {
76
+			resp.Secrets[name] = accessResponseSecret{
77
+				Hex:    s.Hex(),
78
+				Base64: s.Base64(),
79
+			}
80
+		}
81
+	}
51
 
82
 
52
 	ntw, err := makeNetwork(conf, version)
83
 	ntw, err := makeNetwork(conf, version)
53
 	if err != nil {
84
 	if err != nil {
114
 	values.Set("server", ip.String())
145
 	values.Set("server", ip.String())
115
 	values.Set("port", strconv.Itoa(int(portNo)))
146
 	values.Set("port", strconv.Itoa(int(portNo)))
116
 
147
 
148
+	// Use the first available secret (sorted) for URL generation.
149
+	secrets := conf.GetSecrets()
150
+	names := make([]string, 0, len(secrets))
151
+
152
+	for name := range secrets {
153
+		names = append(names, name)
154
+	}
155
+
156
+	sort.Strings(names)
157
+
158
+	var firstSecret mtglib.Secret
159
+	if len(names) > 0 {
160
+		firstSecret = secrets[names[0]]
161
+	}
162
+
117
 	if a.Hex {
163
 	if a.Hex {
118
-		values.Set("secret", conf.Secret.Hex())
164
+		values.Set("secret", firstSecret.Hex())
119
 	} else {
165
 	} else {
120
-		values.Set("secret", conf.Secret.Base64())
166
+		values.Set("secret", firstSecret.Base64())
121
 	}
167
 	}
122
 
168
 
123
 	urlQuery := values.Encode()
169
 	urlQuery := values.Encode()

+ 13
- 5
internal/cli/doctor.go Wyświetl plik

290
 	return err
290
 	return err
291
 }
291
 }
292
 
292
 
293
+func (d *Doctor) getFirstSecretHost() string {
294
+	for _, s := range d.conf.GetSecrets() {
295
+		return s.Host
296
+	}
297
+
298
+	return ""
299
+}
300
+
293
 func (d *Doctor) checkFrontingDomain(ntw mtglib.Network) bool {
301
 func (d *Doctor) checkFrontingDomain(ntw mtglib.Network) bool {
294
-	host := d.conf.Secret.Host
302
+	host := d.getFirstSecretHost()
295
 	if ip := d.conf.GetDomainFrontingIP(nil); ip != "" {
303
 	if ip := d.conf.GetDomainFrontingIP(nil); ip != "" {
296
 		host = ip
304
 		host = ip
297
 	}
305
 	}
323
 }
331
 }
324
 
332
 
325
 func (d *Doctor) checkSecretHost(resolver *net.Resolver, ntw mtglib.Network) bool {
333
 func (d *Doctor) checkSecretHost(resolver *net.Resolver, ntw mtglib.Network) bool {
326
-	addresses, err := resolver.LookupIPAddr(context.Background(), d.conf.Secret.Host)
334
+	addresses, err := resolver.LookupIPAddr(context.Background(), d.getFirstSecretHost())
327
 	if err != nil {
335
 	if err != nil {
328
 		tplError.Execute(os.Stdout, map[string]any{ //nolint: errcheck
336
 		tplError.Execute(os.Stdout, map[string]any{ //nolint: errcheck
329
-			"description": fmt.Sprintf("cannot resolve DNS name of %s", d.conf.Secret.Host),
337
+			"description": fmt.Sprintf("cannot resolve DNS name of %s", d.getFirstSecretHost()),
330
 			"error":       err,
338
 			"error":       err,
331
 		})
339
 		})
332
 		return false
340
 		return false
356
 			(ourIP6 != nil && value.IP.String() == ourIP6.String()) {
364
 			(ourIP6 != nil && value.IP.String() == ourIP6.String()) {
357
 			tplODNSSNIMatch.Execute(os.Stdout, map[string]any{ //nolint: errcheck
365
 			tplODNSSNIMatch.Execute(os.Stdout, map[string]any{ //nolint: errcheck
358
 				"ip":       value.IP,
366
 				"ip":       value.IP,
359
-				"hostname": d.conf.Secret.Host,
367
+				"hostname": d.getFirstSecretHost(),
360
 			})
368
 			})
361
 			return true
369
 			return true
362
 		}
370
 		}
365
 	}
373
 	}
366
 
374
 
367
 	tplEDNSSNIMatch.Execute(os.Stdout, map[string]any{ //nolint: errcheck
375
 	tplEDNSSNIMatch.Execute(os.Stdout, map[string]any{ //nolint: errcheck
368
-		"hostname": d.conf.Secret.Host,
376
+		"hostname": d.getFirstSecretHost(),
369
 		"resolved": strings.Join(strAddresses, ", "),
377
 		"resolved": strings.Join(strAddresses, ", "),
370
 		"ip4":      ourIP4,
378
 		"ip4":      ourIP4,
371
 		"ip6":      ourIP6,
379
 		"ip6":      ourIP6,

+ 3
- 1
internal/cli/run_proxy.go Wyświetl plik

253
 		IPAllowlist:     allowlist,
253
 		IPAllowlist:     allowlist,
254
 		EventStream:     eventStream,
254
 		EventStream:     eventStream,
255
 
255
 
256
-		Secret:                      conf.Secret,
256
+		Secrets:                     conf.GetSecrets(),
257
 		Concurrency:                 conf.GetConcurrency(mtglib.DefaultConcurrency),
257
 		Concurrency:                 conf.GetConcurrency(mtglib.DefaultConcurrency),
258
 		DomainFrontingPort:          conf.GetDomainFrontingPort(mtglib.DefaultDomainFrontingPort),
258
 		DomainFrontingPort:          conf.GetDomainFrontingPort(mtglib.DefaultDomainFrontingPort),
259
 		DomainFrontingIP:            conf.GetDomainFrontingIP(nil),
259
 		DomainFrontingIP:            conf.GetDomainFrontingIP(nil),
269
 		DoppelGangerPerRaid: conf.Defense.Doppelganger.Repeats.Get(mtglib.DoppelGangerPerRaid),
269
 		DoppelGangerPerRaid: conf.Defense.Doppelganger.Repeats.Get(mtglib.DoppelGangerPerRaid),
270
 		DoppelGangerEach:    conf.Defense.Doppelganger.UpdateEach.Get(mtglib.DoppelGangerEach),
270
 		DoppelGangerEach:    conf.Defense.Doppelganger.UpdateEach.Get(mtglib.DoppelGangerEach),
271
 		DoppelGangerDRS:     conf.Defense.Doppelganger.DRS.Get(false),
271
 		DoppelGangerDRS:     conf.Defense.Doppelganger.DRS.Get(false),
272
+
273
+		APIBindTo: conf.APIBindTo.Get(""),
272
 	}
274
 	}
273
 
275
 
274
 	proxy, err := mtglib.NewProxy(opts)
276
 	proxy, err := mtglib.NewProxy(opts)

+ 27
- 7
internal/config/config.go Wyświetl plik

23
 }
23
 }
24
 
24
 
25
 type Config struct {
25
 type Config struct {
26
-	Debug                       TypeBool        `json:"debug"`
27
-	AllowFallbackOnUnknownDC    TypeBool        `json:"allowFallbackOnUnknownDc"`
28
-	Secret                      mtglib.Secret   `json:"secret"`
29
-	BindTo                      TypeHostPort    `json:"bindTo"`
26
+	Debug                       TypeBool                   `json:"debug"`
27
+	AllowFallbackOnUnknownDC    TypeBool                   `json:"allowFallbackOnUnknownDc"`
28
+	Secret                      mtglib.Secret              `json:"secret"`
29
+	Secrets                     map[string]mtglib.Secret   `json:"secrets"`
30
+	BindTo                      TypeHostPort               `json:"bindTo"`
30
 	ProxyProtocolListener       TypeBool        `json:"proxyProtocolListener"`
31
 	ProxyProtocolListener       TypeBool        `json:"proxyProtocolListener"`
31
 	PreferIP                    TypePreferIP    `json:"preferIp"`
32
 	PreferIP                    TypePreferIP    `json:"preferIp"`
32
 	AutoUpdate                  TypeBool        `json:"autoUpdate"`
33
 	AutoUpdate                  TypeBool        `json:"autoUpdate"`
68
 		DNS     TypeDNSURI     `json:"dns"`
69
 		DNS     TypeDNSURI     `json:"dns"`
69
 		Proxies []TypeProxyURL `json:"proxies"`
70
 		Proxies []TypeProxyURL `json:"proxies"`
70
 	} `json:"network"`
71
 	} `json:"network"`
71
-	Stats struct {
72
+	APIBindTo TypeHostPort `json:"apiBindTo"`
73
+	Stats     struct {
72
 		StatsD struct {
74
 		StatsD struct {
73
 			Optional
75
 			Optional
74
 
76
 
125
 }
127
 }
126
 
128
 
127
 func (c *Config) Validate() error {
129
 func (c *Config) Validate() error {
128
-	if !c.Secret.Valid() {
129
-		return fmt.Errorf("invalid secret %s", c.Secret.String())
130
+	if len(c.Secrets) == 0 {
131
+		if !c.Secret.Valid() {
132
+			return fmt.Errorf("invalid secret %s", c.Secret.String())
133
+		}
134
+	} else {
135
+		for name, s := range c.Secrets {
136
+			if !s.Valid() {
137
+				return fmt.Errorf("invalid secret %q: %s", name, s.String())
138
+			}
139
+		}
130
 	}
140
 	}
131
 
141
 
132
 	if c.BindTo.Get("") == "" {
142
 	if c.BindTo.Get("") == "" {
136
 	return nil
146
 	return nil
137
 }
147
 }
138
 
148
 
149
+// GetSecrets returns all secrets as a map. If the new [secrets] section is used,
150
+// returns that map. Otherwise, wraps the single Secret as {"default": Secret}.
151
+func (c *Config) GetSecrets() map[string]mtglib.Secret {
152
+	if len(c.Secrets) > 0 {
153
+		return c.Secrets
154
+	}
155
+
156
+	return map[string]mtglib.Secret{"default": c.Secret}
157
+}
158
+
139
 func (c *Config) String() string {
159
 func (c *Config) String() string {
140
 	buf := &bytes.Buffer{}
160
 	buf := &bytes.Buffer{}
141
 	encoder := json.NewEncoder(buf)
161
 	encoder := json.NewEncoder(buf)

+ 5
- 3
internal/config/parse.go Wyświetl plik

11
 type tomlConfig struct {
11
 type tomlConfig struct {
12
 	Debug                       bool   `toml:"debug" json:"debug,omitempty"`
12
 	Debug                       bool   `toml:"debug" json:"debug,omitempty"`
13
 	AllowFallbackOnUnknownDC    bool   `toml:"allow-fallback-on-unknown-dc" json:"allowFallbackOnUnknownDc,omitempty"`
13
 	AllowFallbackOnUnknownDC    bool   `toml:"allow-fallback-on-unknown-dc" json:"allowFallbackOnUnknownDc,omitempty"`
14
-	Secret                      string `toml:"secret" json:"secret"`
15
-	BindTo                      string `toml:"bind-to" json:"bindTo"`
14
+	Secret                      string            `toml:"secret" json:"secret,omitempty"`
15
+	Secrets                     map[string]string `toml:"secrets" json:"secrets,omitempty"`
16
+	BindTo                      string            `toml:"bind-to" json:"bindTo"`
16
 	ProxyProtocolListener       bool   `toml:"proxy-protocol-listener" json:"proxyProtocolListener"`
17
 	ProxyProtocolListener       bool   `toml:"proxy-protocol-listener" json:"proxyProtocolListener"`
17
 	PreferIP                    string `toml:"prefer-ip" json:"preferIp,omitempty"`
18
 	PreferIP                    string `toml:"prefer-ip" json:"preferIp,omitempty"`
18
 	AutoUpdate                  bool   `toml:"auto-update" json:"autoUpdate,omitempty"`
19
 	AutoUpdate                  bool   `toml:"auto-update" json:"autoUpdate,omitempty"`
63
 		DNS     string   `toml:"dns" json:"dns,omitempty"`
64
 		DNS     string   `toml:"dns" json:"dns,omitempty"`
64
 		Proxies []string `toml:"proxies" json:"proxies,omitempty"`
65
 		Proxies []string `toml:"proxies" json:"proxies,omitempty"`
65
 	} `toml:"network" json:"network,omitempty"`
66
 	} `toml:"network" json:"network,omitempty"`
66
-	Stats struct {
67
+	APIBindTo string `toml:"api-bind-to" json:"apiBindTo,omitempty"`
68
+	Stats     struct {
67
 		StatsD struct {
69
 		StatsD struct {
68
 			Enabled      bool   `toml:"enabled" json:"enabled,omitempty"`
70
 			Enabled      bool   `toml:"enabled" json:"enabled,omitempty"`
69
 			Address      string `toml:"address" json:"address,omitempty"`
71
 			Address      string `toml:"address" json:"address,omitempty"`

+ 35
- 0
mtglib/counting_conn.go Wyświetl plik

1
+package mtglib
2
+
3
+import (
4
+	"github.com/9seconds/mtg/v2/essentials"
5
+)
6
+
7
+// countingConn wraps essentials.Conn and counts bytes through a cached
8
+// *secretStats pointer. The pointer is resolved once at construction time
9
+// so Read/Write never need to acquire a lock.
10
+type countingConn struct {
11
+	essentials.Conn
12
+	st *secretStats
13
+}
14
+
15
+func newCountingConn(conn essentials.Conn, stats *ProxyStats, secretName string) *countingConn {
16
+	return &countingConn{Conn: conn, st: stats.getOrCreate(secretName)}
17
+}
18
+
19
+func (c *countingConn) Read(p []byte) (int, error) {
20
+	n, err := c.Conn.Read(p)
21
+	if n > 0 {
22
+		c.st.bytesIn.Add(int64(n))
23
+	}
24
+
25
+	return n, err
26
+}
27
+
28
+func (c *countingConn) Write(p []byte) (int, error) {
29
+	n, err := c.Conn.Write(p)
30
+	if n > 0 {
31
+		c.st.bytesOut.Add(int64(n))
32
+	}
33
+
34
+	return n, err
35
+}

+ 52
- 27
mtglib/internal/tls/fake/client_side.go Wyświetl plik

36
 	CipherSuite uint16
36
 	CipherSuite uint16
37
 }
37
 }
38
 
38
 
39
+// ReadClientHelloResult contains the parsed ClientHello and the index of the
40
+// secret that matched the HMAC validation.
41
+type ReadClientHelloResult struct {
42
+	Hello        *ClientHello
43
+	MatchedIndex int
44
+}
45
+
39
 func ReadClientHello(
46
 func ReadClientHello(
40
 	conn net.Conn,
47
 	conn net.Conn,
41
 	secret []byte,
48
 	secret []byte,
42
 	hostname string,
49
 	hostname string,
43
 	tolerateTimeSkewness time.Duration,
50
 	tolerateTimeSkewness time.Duration,
44
 ) (*ClientHello, error) {
51
 ) (*ClientHello, error) {
52
+	result, err := ReadClientHelloMulti(conn, [][]byte{secret}, hostname, tolerateTimeSkewness)
53
+	if err != nil {
54
+		return nil, err
55
+	}
56
+
57
+	return result.Hello, nil
58
+}
59
+
60
+// ReadClientHelloMulti is like ReadClientHello but accepts multiple secrets.
61
+// It tries each secret until one validates the HMAC. On success it returns
62
+// the ClientHello and the index of the matched secret.
63
+func ReadClientHelloMulti(
64
+	conn net.Conn,
65
+	secrets [][]byte,
66
+	hostname string,
67
+	tolerateTimeSkewness time.Duration,
68
+) (*ReadClientHelloResult, error) {
45
 	if err := conn.SetReadDeadline(time.Now().Add(ClientHelloReadTimeout)); err != nil {
69
 	if err := conn.SetReadDeadline(time.Now().Add(ClientHelloReadTimeout)); err != nil {
46
 		return nil, fmt.Errorf("cannot set read deadline: %w", err)
70
 		return nil, fmt.Errorf("cannot set read deadline: %w", err)
47
 	}
71
 	}
48
 	defer conn.SetReadDeadline(resetDeadline) //nolint: errcheck
72
 	defer conn.SetReadDeadline(resetDeadline) //nolint: errcheck
49
 
73
 
50
-	// This is how FakeTLS is organized:
51
-	//  1. We create sha256 HMAC with a given secret
52
-	//  2. We dump there a whole TLS frame except of the fact that random
53
-	//     is filled with all zeroes
54
-	//  3. Digest is computed. This digest should be XORed with
55
-	//     original client random
56
-	//  4. New digest should be all 0 except of last 4 bytes
57
-	//  5. Last 4 bytes are little endian uint32 of UNIX timestamp when
58
-	//     this message was created.
59
 	handshakeCopyBuf := &bytes.Buffer{}
74
 	handshakeCopyBuf := &bytes.Buffer{}
60
 	reader := io.TeeReader(conn, handshakeCopyBuf)
75
 	reader := io.TeeReader(conn, handshakeCopyBuf)
61
 
76
 
83
 		return nil, fmt.Errorf("cannot find %s in %v", hostname, sniHostnames)
98
 		return nil, fmt.Errorf("cannot find %s in %v", hostname, sniHostnames)
84
 	}
99
 	}
85
 
100
 
86
-	digest := hmac.New(sha256.New, secret)
87
-	// we write a copy of the handshake with client random all nullified.
88
-	digest.Write(handshakeCopyBuf.Next(RandomOffset))
89
-	handshakeCopyBuf.Next(RandomLen)
90
-	digest.Write(emptyRandom[:])
91
-	digest.Write(handshakeCopyBuf.Bytes())
101
+	// Save the handshake bytes so we can reuse them for each secret attempt.
102
+	handshakeBytes := handshakeCopyBuf.Bytes()
92
 
103
 
93
-	computed := digest.Sum(nil)
104
+	for idx, secret := range secrets {
105
+		digest := hmac.New(sha256.New, secret)
94
 
106
 
95
-	for i := range RandomLen {
96
-		computed[i] ^= hello.Random[i]
97
-	}
107
+		// Write the handshake with client random all nullified.
108
+		digest.Write(handshakeBytes[:RandomOffset])
109
+		digest.Write(emptyRandom[:])
110
+		digest.Write(handshakeBytes[RandomOffset+RandomLen:])
98
 
111
 
99
-	if subtle.ConstantTimeCompare(emptyRandom[:RandomLen-4], computed[:RandomLen-4]) != 1 {
100
-		return nil, ErrBadDigest
101
-	}
112
+		computed := digest.Sum(nil)
113
+
114
+		for i := range RandomLen {
115
+			computed[i] ^= hello.Random[i]
116
+		}
117
+
118
+		if subtle.ConstantTimeCompare(emptyRandom[:RandomLen-4], computed[:RandomLen-4]) != 1 {
119
+			continue
120
+		}
102
 
121
 
103
-	timestamp := int64(binary.LittleEndian.Uint32(computed[RandomLen-4:]))
104
-	createdAt := time.Unix(timestamp, 0)
122
+		timestamp := int64(binary.LittleEndian.Uint32(computed[RandomLen-4:]))
123
+		createdAt := time.Unix(timestamp, 0)
105
 
124
 
106
-	if tdiff := time.Since(createdAt).Abs(); tdiff > tolerateTimeSkewness {
107
-		return nil, fmt.Errorf("timestamp %q is too old %s", createdAt, tdiff)
125
+		if tdiff := time.Since(createdAt).Abs(); tdiff > tolerateTimeSkewness {
126
+			continue
127
+		}
128
+
129
+		return &ReadClientHelloResult{
130
+			Hello:        hello,
131
+			MatchedIndex: idx,
132
+		}, nil
108
 	}
133
 	}
109
 
134
 
110
-	return hello, nil
135
+	return nil, ErrBadDigest
111
 }
136
 }
112
 
137
 
113
 func parseTLSHeader(r io.Reader) (io.Reader, error) {
138
 func parseTLSHeader(r io.Reader) (io.Reader, error) {

+ 137
- 0
mtglib/internal/tls/fake/client_side_test.go Wyświetl plik

3
 import (
3
 import (
4
 	"bytes"
4
 	"bytes"
5
 	"encoding/binary"
5
 	"encoding/binary"
6
+	"encoding/json"
6
 	"errors"
7
 	"errors"
7
 	"io"
8
 	"io"
9
+	"os"
10
+	"path/filepath"
8
 	"testing"
11
 	"testing"
9
 	"time"
12
 	"time"
10
 
13
 
393
 	t.Parallel()
396
 	t.Parallel()
394
 	suite.Run(t, &ParseClientHelloSNITestSuite{})
397
 	suite.Run(t, &ParseClientHelloSNITestSuite{})
395
 }
398
 }
399
+
400
+// --- ReadClientHelloMulti tests ---
401
+
402
+type ReadClientHelloMultiTestSuite struct {
403
+	suite.Suite
404
+
405
+	secret mtglib.Secret
406
+}
407
+
408
+func (suite *ReadClientHelloMultiTestSuite) SetupSuite() {
409
+	parsed, err := mtglib.ParseSecret(
410
+		"ee367a189aee18fa31c190054efd4a8e9573746f726167652e676f6f676c65617069732e636f6d",
411
+	)
412
+	require.NoError(suite.T(), err)
413
+
414
+	suite.secret = parsed
415
+}
416
+
417
+func (suite *ReadClientHelloMultiTestSuite) loadSnapshot(name string) []byte {
418
+	data, err := os.ReadFile(filepath.Join("testdata", name))
419
+	require.NoError(suite.T(), err)
420
+
421
+	snapshot := &clientHelloSnapshot{}
422
+	require.NoError(suite.T(), json.Unmarshal(data, snapshot))
423
+
424
+	return snapshot.GetFull()
425
+}
426
+
427
+func (suite *ReadClientHelloMultiTestSuite) makeConn(data []byte) *parseClientHelloConnMock {
428
+	readBuf := &bytes.Buffer{}
429
+	readBuf.Write(data)
430
+
431
+	connMock := &parseClientHelloConnMock{
432
+		readBuf: readBuf,
433
+	}
434
+
435
+	connMock.
436
+		On("SetReadDeadline", mock.AnythingOfType("time.Time")).
437
+		Twice().
438
+		Return(nil)
439
+
440
+	return connMock
441
+}
442
+
443
+func (suite *ReadClientHelloMultiTestSuite) TestMatchesCorrectSecretAtIndex0() {
444
+	payload := suite.loadSnapshot("client-hello-ok-19dfe38384b9884b.json")
445
+	connMock := suite.makeConn(payload)
446
+	defer connMock.AssertExpectations(suite.T())
447
+
448
+	wrongSecret := mtglib.GenerateSecret("storage.googleapis.com")
449
+
450
+	result, err := fake.ReadClientHelloMulti(
451
+		connMock,
452
+		[][]byte{suite.secret.Key[:], wrongSecret.Key[:]},
453
+		suite.secret.Host,
454
+		TolerateTime,
455
+	)
456
+	suite.NoError(err)
457
+	suite.Equal(0, result.MatchedIndex)
458
+	suite.NotNil(result.Hello)
459
+}
460
+
461
+func (suite *ReadClientHelloMultiTestSuite) TestMatchesCorrectSecretAtIndex1() {
462
+	payload := suite.loadSnapshot("client-hello-ok-19dfe38384b9884b.json")
463
+	connMock := suite.makeConn(payload)
464
+	defer connMock.AssertExpectations(suite.T())
465
+
466
+	wrongSecret := mtglib.GenerateSecret("storage.googleapis.com")
467
+
468
+	result, err := fake.ReadClientHelloMulti(
469
+		connMock,
470
+		[][]byte{wrongSecret.Key[:], suite.secret.Key[:]},
471
+		suite.secret.Host,
472
+		TolerateTime,
473
+	)
474
+	suite.NoError(err)
475
+	suite.Equal(1, result.MatchedIndex)
476
+	suite.NotNil(result.Hello)
477
+}
478
+
479
+func (suite *ReadClientHelloMultiTestSuite) TestMatchesCorrectSecretAtIndex2() {
480
+	payload := suite.loadSnapshot("client-hello-ok-19dfe38384b9884b.json")
481
+	connMock := suite.makeConn(payload)
482
+	defer connMock.AssertExpectations(suite.T())
483
+
484
+	wrong1 := mtglib.GenerateSecret("storage.googleapis.com")
485
+	wrong2 := mtglib.GenerateSecret("storage.googleapis.com")
486
+
487
+	result, err := fake.ReadClientHelloMulti(
488
+		connMock,
489
+		[][]byte{wrong1.Key[:], wrong2.Key[:], suite.secret.Key[:]},
490
+		suite.secret.Host,
491
+		TolerateTime,
492
+	)
493
+	suite.NoError(err)
494
+	suite.Equal(2, result.MatchedIndex)
495
+	suite.NotNil(result.Hello)
496
+}
497
+
498
+func (suite *ReadClientHelloMultiTestSuite) TestNoMatchReturnsBadDigest() {
499
+	payload := suite.loadSnapshot("client-hello-ok-19dfe38384b9884b.json")
500
+	connMock := suite.makeConn(payload)
501
+	defer connMock.AssertExpectations(suite.T())
502
+
503
+	wrong1 := mtglib.GenerateSecret("storage.googleapis.com")
504
+	wrong2 := mtglib.GenerateSecret("storage.googleapis.com")
505
+
506
+	_, err := fake.ReadClientHelloMulti(
507
+		connMock,
508
+		[][]byte{wrong1.Key[:], wrong2.Key[:]},
509
+		suite.secret.Host,
510
+		TolerateTime,
511
+	)
512
+	suite.ErrorIs(err, fake.ErrBadDigest)
513
+}
514
+
515
+func (suite *ReadClientHelloMultiTestSuite) TestBadSnapshotReturnsBadDigest() {
516
+	payload := suite.loadSnapshot("client-hello-bad-fa2e46cdb33e2a1b.json")
517
+	connMock := suite.makeConn(payload)
518
+	defer connMock.AssertExpectations(suite.T())
519
+
520
+	_, err := fake.ReadClientHelloMulti(
521
+		connMock,
522
+		[][]byte{suite.secret.Key[:]},
523
+		suite.secret.Host,
524
+		TolerateTime,
525
+	)
526
+	suite.ErrorIs(err, fake.ErrBadDigest)
527
+}
528
+
529
+func TestReadClientHelloMulti(t *testing.T) {
530
+	t.Parallel()
531
+	suite.Run(t, &ReadClientHelloMultiTestSuite{})
532
+}

+ 64
- 14
mtglib/proxy.go Wyświetl plik

5
 	"errors"
5
 	"errors"
6
 	"fmt"
6
 	"fmt"
7
 	"net"
7
 	"net"
8
+	"sort"
8
 	"strconv"
9
 	"strconv"
9
 	"sync"
10
 	"sync"
10
 	"time"
11
 	"time"
35
 	telegram                    *dc.Telegram
36
 	telegram                    *dc.Telegram
36
 	configUpdater               *dc.PublicConfigUpdater
37
 	configUpdater               *dc.PublicConfigUpdater
37
 	doppelGanger                *doppel.Ganger
38
 	doppelGanger                *doppel.Ganger
38
-	clientObfuscatror           obfuscation.Obfuscator
39
 
39
 
40
-	secret          Secret
40
+	stats       *ProxyStats
41
+	secrets     []Secret
42
+	secretNames []string
41
 	network         Network
43
 	network         Network
42
 	antiReplayCache AntiReplayCache
44
 	antiReplayCache AntiReplayCache
43
 	blocklist       IPBlocklist
45
 	blocklist       IPBlocklist
49
 // DomainFrontingAddress returns a host:port pair for a fronting domain.
51
 // DomainFrontingAddress returns a host:port pair for a fronting domain.
50
 // If DomainFrontingIP is set, it is used instead of resolving the hostname.
52
 // If DomainFrontingIP is set, it is used instead of resolving the hostname.
51
 func (p *Proxy) DomainFrontingAddress() string {
53
 func (p *Proxy) DomainFrontingAddress() string {
52
-	host := p.secret.Host
54
+	// All secrets share the same host (enforced by validation),
55
+	// so we use the first one.
56
+	host := p.secrets[0].Host
53
 	if p.domainFrontingIP != "" {
57
 	if p.domainFrontingIP != "" {
54
 		host = p.domainFrontingIP
58
 		host = p.domainFrontingIP
55
 	}
59
 	}
83
 		return
87
 		return
84
 	}
88
 	}
85
 
89
 
90
+	p.stats.OnConnect(ctx.secretName)
91
+	p.stats.UpdateLastSeen(ctx.secretName)
92
+
93
+	defer p.stats.OnDisconnect(ctx.secretName)
94
+
86
 	clientConn, err := p.doppelGanger.NewConn(ctx.clientConn)
95
 	clientConn, err := p.doppelGanger.NewConn(ctx.clientConn)
87
 	if err != nil {
96
 	if err != nil {
88
 		ctx.logger.InfoError("cannot wrap into doppelganger connection", err)
97
 		ctx.logger.InfoError("cannot wrap into doppelganger connection", err)
102
 		return
111
 		return
103
 	}
112
 	}
104
 
113
 
114
+	countedClientConn := newCountingConn(ctx.clientConn, p.stats, ctx.secretName)
115
+
105
 	relay.Relay(
116
 	relay.Relay(
106
 		ctx,
117
 		ctx,
107
 		ctx.logger.Named("relay"),
118
 		ctx.logger.Named("relay"),
108
 		ctx.telegramConn,
119
 		ctx.telegramConn,
109
-		ctx.clientConn,
120
+		countedClientConn,
110
 	)
121
 	)
111
 }
122
 }
112
 
123
 
175
 func (p *Proxy) doFakeTLSHandshake(ctx *streamContext) bool {
186
 func (p *Proxy) doFakeTLSHandshake(ctx *streamContext) bool {
176
 	rewind := newConnRewind(ctx.clientConn)
187
 	rewind := newConnRewind(ctx.clientConn)
177
 
188
 
178
-	clientHello, err := fake.ReadClientHello(
189
+	// Build a slice of secret keys to try during HMAC validation.
190
+	secretKeys := make([][]byte, len(p.secrets))
191
+	for i := range p.secrets {
192
+		secretKeys[i] = p.secrets[i].Key[:]
193
+	}
194
+
195
+	result, err := fake.ReadClientHelloMulti(
179
 		rewind,
196
 		rewind,
180
-		p.secret.Key[:],
181
-		p.secret.Host,
197
+		secretKeys,
198
+		p.secrets[0].Host,
182
 		p.tolerateTimeSkewness,
199
 		p.tolerateTimeSkewness,
183
 	)
200
 	)
184
 	if err != nil {
201
 	if err != nil {
187
 		return false
204
 		return false
188
 	}
205
 	}
189
 
206
 
190
-	if p.antiReplayCache.SeenBefore(clientHello.SessionID) {
207
+	if p.antiReplayCache.SeenBefore(result.Hello.SessionID) {
191
 		p.logger.Warning("replay attack has been detected!")
208
 		p.logger.Warning("replay attack has been detected!")
192
 		p.eventStream.Send(p.ctx, NewEventReplayAttack(ctx.streamID))
209
 		p.eventStream.Send(p.ctx, NewEventReplayAttack(ctx.streamID))
193
 		p.doDomainFronting(ctx, rewind)
210
 		p.doDomainFronting(ctx, rewind)
194
 		return false
211
 		return false
195
 	}
212
 	}
196
 
213
 
214
+	matchedSecret := p.secrets[result.MatchedIndex]
215
+	ctx.matchedSecretKey = matchedSecret.Key[:]
216
+	ctx.secretName = p.secretNames[result.MatchedIndex]
217
+	ctx.logger = ctx.logger.BindStr("secret_name", ctx.secretName)
218
+
197
 	gangerNoise := p.doppelGanger.NoiseParams()
219
 	gangerNoise := p.doppelGanger.NoiseParams()
198
 	noiseParams := fake.NoiseParams{Mean: gangerNoise.Mean, Jitter: gangerNoise.Jitter}
220
 	noiseParams := fake.NoiseParams{Mean: gangerNoise.Mean, Jitter: gangerNoise.Jitter}
199
 
221
 
200
-	if err := fake.SendServerHello(ctx.clientConn, p.secret.Key[:], clientHello, noiseParams); err != nil {
222
+	if err := fake.SendServerHello(ctx.clientConn, matchedSecret.Key[:], result.Hello, noiseParams); err != nil {
201
 		p.logger.InfoError("cannot send welcome packet", err)
223
 		p.logger.InfoError("cannot send welcome packet", err)
202
 		return false
224
 		return false
203
 	}
225
 	}
208
 }
230
 }
209
 
231
 
210
 func (p *Proxy) doObfuscatedHandshake(ctx *streamContext) error {
232
 func (p *Proxy) doObfuscatedHandshake(ctx *streamContext) error {
211
-	dc, conn, err := p.clientObfuscatror.ReadHandshake(ctx.clientConn)
233
+	// Use the secret key that was matched during the FakeTLS handshake.
234
+	obfs := obfuscation.Obfuscator{
235
+		Secret: ctx.matchedSecretKey,
236
+	}
237
+
238
+	dc, conn, err := obfs.ReadHandshake(ctx.clientConn)
212
 	if err != nil {
239
 	if err != nil {
213
 		return fmt.Errorf("cannot process client handshake: %w", err)
240
 		return fmt.Errorf("cannot process client handshake: %w", err)
214
 	}
241
 	}
328
 	logger := opts.getLogger("proxy")
355
 	logger := opts.getLogger("proxy")
329
 	updatersLogger := logger.Named("telegram-updaters")
356
 	updatersLogger := logger.Named("telegram-updaters")
330
 
357
 
358
+	secretsMap := opts.getSecrets()
359
+	secretNames := make([]string, 0, len(secretsMap))
360
+
361
+	for name := range secretsMap {
362
+		secretNames = append(secretNames, name)
363
+	}
364
+
365
+	sort.Strings(secretNames)
366
+
367
+	secretsList := make([]Secret, 0, len(secretsMap))
368
+
369
+	for _, name := range secretNames {
370
+		secretsList = append(secretsList, secretsMap[name])
371
+	}
372
+
373
+	stats := NewProxyStats()
374
+	for _, name := range secretNames {
375
+		stats.PreRegister(name)
376
+	}
377
+
378
+	if opts.APIBindTo != "" {
379
+		stats.StartServer(ctx, opts.APIBindTo, logger)
380
+	}
381
+
331
 	proxy := &Proxy{
382
 	proxy := &Proxy{
332
 		ctx:                      ctx,
383
 		ctx:                      ctx,
333
 		ctxCancel:                cancel,
384
 		ctxCancel:                cancel,
334
-		secret:                   opts.Secret,
385
+		stats:                    stats,
386
+		secrets:                  secretsList,
387
+		secretNames:              secretNames,
335
 		network:                  opts.Network,
388
 		network:                  opts.Network,
336
 		antiReplayCache:          opts.AntiReplayCache,
389
 		antiReplayCache:          opts.AntiReplayCache,
337
 		blocklist:                opts.IPBlocklist,
390
 		blocklist:                opts.IPBlocklist,
358
 			updatersLogger.Named("public-config"),
411
 			updatersLogger.Named("public-config"),
359
 			opts.Network.MakeHTTPClient(nil),
412
 			opts.Network.MakeHTTPClient(nil),
360
 		),
413
 		),
361
-		clientObfuscatror: obfuscation.Obfuscator{
362
-			Secret: opts.Secret.Key[:],
363
-		},
364
 		domainFrontingProxyProtocol: opts.DomainFrontingProxyProtocol,
414
 		domainFrontingProxyProtocol: opts.DomainFrontingProxyProtocol,
365
 	}
415
 	}
366
 
416
 

+ 48
- 3
mtglib/proxy_opts.go Wyświetl plik

1
 package mtglib
1
 package mtglib
2
 
2
 
3
-import "time"
3
+import (
4
+	"fmt"
5
+	"time"
6
+)
4
 
7
 
5
 // ProxyOpts is a structure with settings to mtg proxy.
8
 // ProxyOpts is a structure with settings to mtg proxy.
6
 //
9
 //
9
 type ProxyOpts struct {
12
 type ProxyOpts struct {
10
 	// Secret defines a secret which should be used by a proxy.
13
 	// Secret defines a secret which should be used by a proxy.
11
 	//
14
 	//
12
-	// This is a mandatory setting.
15
+	// Deprecated: Use Secrets instead for multi-secret support.
16
+	// Kept for backward compatibility.
13
 	Secret Secret
17
 	Secret Secret
14
 
18
 
19
+	// Secrets defines a map of named secrets which should be used by a proxy.
20
+	// If set, Secret is ignored. During FakeTLS handshake, each secret is
21
+	// tried until one validates.
22
+	Secrets map[string]Secret
23
+
15
 	// Network defines a network instance which should be used for all network
24
 	// Network defines a network instance which should be used for all network
16
 	// communications made by proxies.
25
 	// communications made by proxies.
17
 	//
26
 	//
161
 	// DoppelGangerDRS defines if TLS Dynamic Record Sizing is active.
170
 	// DoppelGangerDRS defines if TLS Dynamic Record Sizing is active.
162
 	DoppelGangerDRS bool
171
 	DoppelGangerDRS bool
163
 
172
 
173
+	// APIBindTo is the address to bind the stats HTTP API server to.
174
+	// If empty, the stats API server is not started.
175
+	//
176
+	// This is an optional setting.
177
+	APIBindTo string
164
 }
178
 }
165
 
179
 
166
 func (p ProxyOpts) valid() error {
180
 func (p ProxyOpts) valid() error {
177
 		return ErrEventStreamIsNotDefined
191
 		return ErrEventStreamIsNotDefined
178
 	case p.Logger == nil:
192
 	case p.Logger == nil:
179
 		return ErrLoggerIsNotDefined
193
 		return ErrLoggerIsNotDefined
180
-	case !p.Secret.Valid():
194
+	}
195
+
196
+	secrets := p.getSecrets()
197
+	if len(secrets) == 0 {
181
 		return ErrSecretInvalid
198
 		return ErrSecretInvalid
182
 	}
199
 	}
183
 
200
 
201
+	var host string
202
+
203
+	for _, s := range secrets {
204
+		if !s.Valid() {
205
+			return ErrSecretInvalid
206
+		}
207
+
208
+		if host == "" {
209
+			host = s.Host
210
+		} else if s.Host != host {
211
+			return fmt.Errorf("all secrets must use the same hostname, got %q and %q", host, s.Host)
212
+		}
213
+	}
214
+
215
+	return nil
216
+}
217
+
218
+// getSecrets returns the effective secrets map. If Secrets is populated, it is
219
+// returned directly. Otherwise the single Secret is wrapped in a map.
220
+func (p ProxyOpts) getSecrets() map[string]Secret {
221
+	if len(p.Secrets) > 0 {
222
+		return p.Secrets
223
+	}
224
+
225
+	if p.Secret.Valid() {
226
+		return map[string]Secret{"default": p.Secret}
227
+	}
228
+
184
 	return nil
229
 	return nil
185
 }
230
 }
186
 
231
 

+ 179
- 0
mtglib/proxy_stats.go Wyświetl plik

1
+package mtglib
2
+
3
+import (
4
+	"context"
5
+	"encoding/json"
6
+	"net"
7
+	"net/http"
8
+	"sync"
9
+	"sync/atomic"
10
+	"time"
11
+)
12
+
13
+type secretStats struct {
14
+	connections atomic.Int64
15
+	bytesIn     atomic.Int64
16
+	bytesOut    atomic.Int64
17
+	lastSeen    atomic.Value // stores time.Time
18
+}
19
+
20
+// ProxyStats tracks per-secret connection stats with atomic counters.
21
+// Thread-safe for concurrent access from proxy goroutines.
22
+type ProxyStats struct {
23
+	mu        sync.RWMutex
24
+	users     map[string]*secretStats
25
+	startedAt time.Time
26
+}
27
+
28
+// NewProxyStats creates a new ProxyStats instance.
29
+func NewProxyStats() *ProxyStats {
30
+	return &ProxyStats{
31
+		users:     make(map[string]*secretStats),
32
+		startedAt: time.Now(),
33
+	}
34
+}
35
+
36
+func (s *ProxyStats) getOrCreate(name string) *secretStats {
37
+	s.mu.RLock()
38
+	st, ok := s.users[name]
39
+	s.mu.RUnlock()
40
+
41
+	if ok {
42
+		return st
43
+	}
44
+
45
+	s.mu.Lock()
46
+	defer s.mu.Unlock()
47
+
48
+	if st, ok = s.users[name]; ok {
49
+		return st
50
+	}
51
+
52
+	st = &secretStats{}
53
+	st.lastSeen.Store(time.Time{})
54
+	s.users[name] = st
55
+
56
+	return st
57
+}
58
+
59
+// PreRegister adds a secret name to the stats map so it appears in output
60
+// even if no connections have been made yet.
61
+func (s *ProxyStats) PreRegister(name string) {
62
+	s.getOrCreate(name)
63
+}
64
+
65
+// OnConnect increments the active connection count for the given secret.
66
+func (s *ProxyStats) OnConnect(name string) {
67
+	s.getOrCreate(name).connections.Add(1)
68
+}
69
+
70
+// OnDisconnect decrements the active connection count for the given secret.
71
+func (s *ProxyStats) OnDisconnect(name string) {
72
+	s.getOrCreate(name).connections.Add(-1)
73
+}
74
+
75
+// AddBytesIn adds to the bytes-in counter for the given secret.
76
+func (s *ProxyStats) AddBytesIn(name string, n int64) {
77
+	s.getOrCreate(name).bytesIn.Add(n)
78
+}
79
+
80
+// AddBytesOut adds to the bytes-out counter for the given secret.
81
+func (s *ProxyStats) AddBytesOut(name string, n int64) {
82
+	s.getOrCreate(name).bytesOut.Add(n)
83
+}
84
+
85
+// UpdateLastSeen sets the last-seen timestamp for the given secret to now.
86
+func (s *ProxyStats) UpdateLastSeen(name string) {
87
+	s.getOrCreate(name).lastSeen.Store(time.Now())
88
+}
89
+
90
+// StatsResponse is the JSON response for the stats endpoint.
91
+type StatsResponse struct {
92
+	StartedAt        time.Time                `json:"started_at"`
93
+	UptimeSeconds    int64                    `json:"uptime_seconds"`
94
+	TotalConnections int64                    `json:"total_connections"`
95
+	Users            map[string]UserStatsJSON `json:"users"`
96
+}
97
+
98
+// UserStatsJSON is the per-user portion of the stats JSON response.
99
+type UserStatsJSON struct {
100
+	Connections int64      `json:"connections"`
101
+	BytesIn     int64      `json:"bytes_in"`
102
+	BytesOut    int64      `json:"bytes_out"`
103
+	LastSeen    *time.Time `json:"last_seen"`
104
+}
105
+
106
+func (s *ProxyStats) ServeHTTP(w http.ResponseWriter, r *http.Request) {
107
+	s.mu.RLock()
108
+	defer s.mu.RUnlock()
109
+
110
+	var totalConns int64
111
+
112
+	users := make(map[string]UserStatsJSON, len(s.users))
113
+
114
+	for name, st := range s.users {
115
+		conns := st.connections.Load()
116
+		totalConns += conns
117
+
118
+		lastSeen := st.lastSeen.Load().(time.Time) //nolint: forcetypeassert
119
+		var lastSeenPtr *time.Time
120
+		if !lastSeen.IsZero() {
121
+			lastSeenPtr = &lastSeen
122
+		}
123
+
124
+		users[name] = UserStatsJSON{
125
+			Connections: conns,
126
+			BytesIn:     st.bytesIn.Load(),
127
+			BytesOut:    st.bytesOut.Load(),
128
+			LastSeen:    lastSeenPtr,
129
+		}
130
+	}
131
+
132
+	resp := StatsResponse{
133
+		StartedAt:        s.startedAt,
134
+		UptimeSeconds:    int64(time.Since(s.startedAt).Seconds()),
135
+		TotalConnections: totalConns,
136
+		Users:            users,
137
+	}
138
+
139
+	w.Header().Set("Content-Type", "application/json")
140
+
141
+	if err := json.NewEncoder(w).Encode(resp); err != nil {
142
+		http.Error(w, err.Error(), http.StatusInternalServerError)
143
+	}
144
+}
145
+
146
+// StartServer starts an HTTP server for the stats API in a background goroutine.
147
+// The server is shut down when ctx is cancelled.
148
+func (s *ProxyStats) StartServer(ctx context.Context, bindTo string, logger Logger) {
149
+	mux := http.NewServeMux()
150
+	mux.Handle("/stats", s)
151
+
152
+	srv := &http.Server{
153
+		Addr:    bindTo,
154
+		Handler: mux,
155
+	}
156
+
157
+	ln, err := net.Listen("tcp", bindTo)
158
+	if err != nil {
159
+		logger.WarningError("cannot start stats API listener", err)
160
+		return
161
+	}
162
+
163
+	go func() {
164
+		if err := srv.Serve(ln); err != nil && err != http.ErrServerClosed {
165
+			logger.WarningError("stats API server error", err)
166
+		}
167
+	}()
168
+
169
+	go func() {
170
+		<-ctx.Done()
171
+
172
+		shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) //nolint: mnd
173
+		defer cancel()
174
+
175
+		srv.Shutdown(shutdownCtx) //nolint: errcheck
176
+	}()
177
+
178
+	logger.BindStr("bind", bindTo).Info("Stats API server started")
179
+}

+ 183
- 0
mtglib/proxy_stats_test.go Wyświetl plik

1
+package mtglib
2
+
3
+import (
4
+	"encoding/json"
5
+	"net/http"
6
+	"net/http/httptest"
7
+	"testing"
8
+	"time"
9
+
10
+	"github.com/stretchr/testify/assert"
11
+	"github.com/stretchr/testify/require"
12
+)
13
+
14
+func TestNewProxyStats(t *testing.T) {
15
+	t.Parallel()
16
+
17
+	stats := NewProxyStats()
18
+	assert.NotNil(t, stats)
19
+	assert.NotNil(t, stats.users)
20
+	assert.False(t, stats.startedAt.IsZero())
21
+}
22
+
23
+func TestPreRegister(t *testing.T) {
24
+	t.Parallel()
25
+
26
+	stats := NewProxyStats()
27
+	stats.PreRegister("alice")
28
+	stats.PreRegister("bob")
29
+
30
+	stats.mu.RLock()
31
+	defer stats.mu.RUnlock()
32
+
33
+	assert.Contains(t, stats.users, "alice")
34
+	assert.Contains(t, stats.users, "bob")
35
+	assert.Equal(t, int64(0), stats.users["alice"].connections.Load())
36
+}
37
+
38
+func TestOnConnectDisconnect(t *testing.T) {
39
+	t.Parallel()
40
+
41
+	stats := NewProxyStats()
42
+	stats.PreRegister("alice")
43
+
44
+	stats.OnConnect("alice")
45
+	assert.Equal(t, int64(1), stats.users["alice"].connections.Load())
46
+
47
+	stats.OnConnect("alice")
48
+	assert.Equal(t, int64(2), stats.users["alice"].connections.Load())
49
+
50
+	stats.OnDisconnect("alice")
51
+	assert.Equal(t, int64(1), stats.users["alice"].connections.Load())
52
+
53
+	stats.OnDisconnect("alice")
54
+	assert.Equal(t, int64(0), stats.users["alice"].connections.Load())
55
+}
56
+
57
+func TestAddBytes(t *testing.T) {
58
+	t.Parallel()
59
+
60
+	stats := NewProxyStats()
61
+	stats.PreRegister("alice")
62
+
63
+	stats.AddBytesIn("alice", 100)
64
+	stats.AddBytesIn("alice", 200)
65
+	stats.AddBytesOut("alice", 50)
66
+
67
+	st := stats.users["alice"]
68
+	assert.Equal(t, int64(300), st.bytesIn.Load())
69
+	assert.Equal(t, int64(50), st.bytesOut.Load())
70
+}
71
+
72
+func TestUpdateLastSeen(t *testing.T) {
73
+	t.Parallel()
74
+
75
+	stats := NewProxyStats()
76
+	stats.PreRegister("alice")
77
+
78
+	before := time.Now()
79
+	stats.UpdateLastSeen("alice")
80
+	after := time.Now()
81
+
82
+	lastSeen := stats.users["alice"].lastSeen.Load().(time.Time)
83
+	assert.False(t, lastSeen.Before(before))
84
+	assert.False(t, lastSeen.After(after))
85
+}
86
+
87
+func TestGetOrCreateLazy(t *testing.T) {
88
+	t.Parallel()
89
+
90
+	stats := NewProxyStats()
91
+
92
+	// getOrCreate should create a new entry on first access.
93
+	stats.OnConnect("new-user")
94
+	assert.Equal(t, int64(1), stats.users["new-user"].connections.Load())
95
+}
96
+
97
+func TestServeHTTPBasic(t *testing.T) {
98
+	t.Parallel()
99
+
100
+	stats := NewProxyStats()
101
+	stats.PreRegister("alice")
102
+	stats.PreRegister("bob")
103
+
104
+	stats.OnConnect("alice")
105
+	stats.OnConnect("alice")
106
+	stats.OnConnect("bob")
107
+	stats.AddBytesIn("alice", 1024)
108
+	stats.AddBytesOut("alice", 512)
109
+	stats.UpdateLastSeen("alice")
110
+
111
+	rec := httptest.NewRecorder()
112
+	req := httptest.NewRequest(http.MethodGet, "/stats", nil)
113
+
114
+	stats.ServeHTTP(rec, req)
115
+
116
+	assert.Equal(t, http.StatusOK, rec.Code)
117
+	assert.Equal(t, "application/json", rec.Header().Get("Content-Type"))
118
+
119
+	var resp StatsResponse
120
+	err := json.Unmarshal(rec.Body.Bytes(), &resp)
121
+	require.NoError(t, err)
122
+
123
+	assert.Equal(t, int64(3), resp.TotalConnections)
124
+	assert.False(t, resp.StartedAt.IsZero())
125
+	assert.GreaterOrEqual(t, resp.UptimeSeconds, int64(0))
126
+
127
+	alice, ok := resp.Users["alice"]
128
+	require.True(t, ok)
129
+	assert.Equal(t, int64(2), alice.Connections)
130
+	assert.Equal(t, int64(1024), alice.BytesIn)
131
+	assert.Equal(t, int64(512), alice.BytesOut)
132
+	assert.NotNil(t, alice.LastSeen)
133
+
134
+	bob, ok := resp.Users["bob"]
135
+	require.True(t, ok)
136
+	assert.Equal(t, int64(1), bob.Connections)
137
+	assert.Equal(t, int64(0), bob.BytesIn)
138
+	assert.Equal(t, int64(0), bob.BytesOut)
139
+	assert.Nil(t, bob.LastSeen)
140
+}
141
+
142
+func TestServeHTTPEmpty(t *testing.T) {
143
+	t.Parallel()
144
+
145
+	stats := NewProxyStats()
146
+
147
+	rec := httptest.NewRecorder()
148
+	req := httptest.NewRequest(http.MethodGet, "/stats", nil)
149
+
150
+	stats.ServeHTTP(rec, req)
151
+
152
+	assert.Equal(t, http.StatusOK, rec.Code)
153
+
154
+	var resp StatsResponse
155
+	err := json.Unmarshal(rec.Body.Bytes(), &resp)
156
+	require.NoError(t, err)
157
+
158
+	assert.Empty(t, resp.Users)
159
+	assert.Equal(t, int64(0), resp.TotalConnections)
160
+}
161
+
162
+func TestServeHTTPLastSeenZeroIsNull(t *testing.T) {
163
+	t.Parallel()
164
+
165
+	stats := NewProxyStats()
166
+	stats.PreRegister("alice")
167
+
168
+	rec := httptest.NewRecorder()
169
+	req := httptest.NewRequest(http.MethodGet, "/stats", nil)
170
+
171
+	stats.ServeHTTP(rec, req)
172
+
173
+	var raw map[string]json.RawMessage
174
+	require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &raw))
175
+
176
+	var users map[string]json.RawMessage
177
+	require.NoError(t, json.Unmarshal(raw["users"], &users))
178
+
179
+	var aliceRaw map[string]json.RawMessage
180
+	require.NoError(t, json.Unmarshal(users["alice"], &aliceRaw))
181
+
182
+	assert.Equal(t, "null", string(aliceRaw["last_seen"]))
183
+}

+ 9
- 7
mtglib/stream_context.go Wyświetl plik

11
 )
11
 )
12
 
12
 
13
 type streamContext struct {
13
 type streamContext struct {
14
-	ctx          context.Context
15
-	ctxCancel    context.CancelFunc
16
-	clientConn   essentials.Conn
17
-	telegramConn essentials.Conn
18
-	streamID     string
19
-	dc           int
20
-	logger       Logger
14
+	ctx              context.Context
15
+	ctxCancel        context.CancelFunc
16
+	clientConn       essentials.Conn
17
+	telegramConn     essentials.Conn
18
+	streamID         string
19
+	dc               int
20
+	matchedSecretKey []byte
21
+	secretName       string
22
+	logger           Logger
21
 }
23
 }
22
 
24
 
23
 func (s *streamContext) Deadline() (time.Time, bool) {
25
 func (s *streamContext) Deadline() (time.Time, bool) {

Ładowanie…
Anuluj
Zapisz