Pārlūkot izejas kodu

Add per-user connection throttling with fair-share algorithm

When total connections exceed a configurable limit, a background
goroutine (every 5s by default) computes per-user caps using a
fair-share algorithm: small users keep their connections, remaining
budget is split equally among heavy users. New connections from
over-cap users are rejected; existing connections are not killed.

Config:
  [throttle]
  max-connections = 5000
  check-interval = "5s"

Stats API response now includes throttle state with active caps.
pull/434/head
Alexey Dolotov 1 mēnesi atpakaļ
vecāks
revīzija
5dc9ba353a

+ 3
- 0
internal/cli/run_proxy.go Parādīt failu

@@ -271,6 +271,9 @@ func runProxy(conf *config.Config, version string) error { //nolint: funlen
271 271
 		DoppelGangerDRS:     conf.Defense.Doppelganger.DRS.Get(false),
272 272
 
273 273
 		APIBindTo: conf.APIBindTo.Get(""),
274
+
275
+		ThrottleMaxConnections: conf.Throttle.MaxConnections.Get(0),
276
+		ThrottleCheckInterval:  conf.Throttle.CheckInterval.Get(5 * time.Second),
274 277
 	}
275 278
 
276 279
 	proxy, err := mtglib.NewProxy(opts)

+ 5
- 1
internal/config/config.go Parādīt failu

@@ -70,7 +70,11 @@ type Config struct {
70 70
 		Proxies []TypeProxyURL `json:"proxies"`
71 71
 	} `json:"network"`
72 72
 	APIBindTo TypeHostPort `json:"apiBindTo"`
73
-	Stats     struct {
73
+	Throttle  struct {
74
+		MaxConnections TypeConcurrency `json:"maxConnections"`
75
+		CheckInterval  TypeDuration    `json:"checkInterval"`
76
+	} `json:"throttle"`
77
+	Stats struct {
74 78
 		StatsD struct {
75 79
 			Optional
76 80
 

+ 5
- 1
internal/config/parse.go Parādīt failu

@@ -65,7 +65,11 @@ type tomlConfig struct {
65 65
 		Proxies []string `toml:"proxies" json:"proxies,omitempty"`
66 66
 	} `toml:"network" json:"network,omitempty"`
67 67
 	APIBindTo string `toml:"api-bind-to" json:"apiBindTo,omitempty"`
68
-	Stats     struct {
68
+	Throttle  struct {
69
+		MaxConnections uint   `toml:"max-connections" json:"maxConnections,omitempty"`
70
+		CheckInterval  string `toml:"check-interval" json:"checkInterval,omitempty"`
71
+	} `toml:"throttle" json:"throttle,omitempty"`
72
+	Stats struct {
69 73
 		StatsD struct {
70 74
 			Enabled      bool   `toml:"enabled" json:"enabled,omitempty"`
71 75
 			Address      string `toml:"address" json:"address,omitempty"`

+ 19
- 0
mtglib/events.go Parādīt failu

@@ -93,6 +93,14 @@ type EventReplayAttack struct {
93 93
 	eventBase
94 94
 }
95 95
 
96
+// EventThrottled is emitted when a connection is rejected because the
97
+// per-user connection cap has been reached.
98
+type EventThrottled struct {
99
+	eventBase
100
+
101
+	SecretName string
102
+}
103
+
96 104
 // EventIPListSize is emitted when mtg updates a contents of the ip lists:
97 105
 // allowlist or blocklist.
98 106
 type EventIPListSize struct {
@@ -200,6 +208,17 @@ func NewEventReplayAttack(streamID string) EventReplayAttack {
200 208
 	}
201 209
 }
202 210
 
211
+// NewEventThrottled creates a new EventThrottled event.
212
+func NewEventThrottled(streamID, secretName string) EventThrottled {
213
+	return EventThrottled{
214
+		eventBase: eventBase{
215
+			timestamp: time.Now(),
216
+			streamID:  streamID,
217
+		},
218
+		SecretName: secretName,
219
+	}
220
+}
221
+
203 222
 // NewEventIPListSize creates a new EventIPListSize event.
204 223
 func NewEventIPListSize(size int, isBlockList bool) EventIPListSize {
205 224
 	return EventIPListSize{

+ 12
- 0
mtglib/proxy.go Parādīt failu

@@ -87,6 +87,13 @@ func (p *Proxy) ServeConn(conn essentials.Conn) {
87 87
 		return
88 88
 	}
89 89
 
90
+	if !p.stats.CanConnect(ctx.secretName) {
91
+		ctx.logger.Info("connection throttled")
92
+		p.eventStream.Send(ctx, NewEventThrottled(ctx.streamID, ctx.secretName))
93
+
94
+		return
95
+	}
96
+
90 97
 	p.stats.OnConnect(ctx.secretName)
91 98
 	p.stats.UpdateLastSeen(ctx.secretName)
92 99
 
@@ -381,6 +388,11 @@ func NewProxy(opts ProxyOpts) (*Proxy, error) {
381 388
 		stats.StartServer(ctx, opts.APIBindTo, logger)
382 389
 	}
383 390
 
391
+	if opts.ThrottleMaxConnections > 0 {
392
+		stats.SetThrottle(int64(opts.ThrottleMaxConnections), opts.getThrottleCheckInterval())
393
+		stats.startThrottleLoop(ctx, logger)
394
+	}
395
+
384 396
 	proxy := &Proxy{
385 397
 		ctx:                      ctx,
386 398
 		ctxCancel:                cancel,

+ 22
- 0
mtglib/proxy_opts.go Parādīt failu

@@ -175,6 +175,20 @@ type ProxyOpts struct {
175 175
 	//
176 176
 	// This is an optional setting.
177 177
 	APIBindTo string
178
+
179
+	// ThrottleMaxConnections is the total connection limit. When total
180
+	// connections exceed this value, per-user caps are computed using
181
+	// a fair-share algorithm and new connections from over-cap users
182
+	// are rejected. 0 disables throttling.
183
+	//
184
+	// This is an optional setting.
185
+	ThrottleMaxConnections uint
186
+
187
+	// ThrottleCheckInterval is how often the throttle recomputes per-user
188
+	// caps. Defaults to 5 seconds.
189
+	//
190
+	// This is an optional setting.
191
+	ThrottleCheckInterval time.Duration
178 192
 }
179 193
 
180 194
 func (p ProxyOpts) valid() error {
@@ -269,6 +283,14 @@ func (p ProxyOpts) getIdleTimeout() time.Duration {
269 283
 	return p.IdleTimeout
270 284
 }
271 285
 
286
+func (p ProxyOpts) getThrottleCheckInterval() time.Duration {
287
+	if p.ThrottleCheckInterval == 0 {
288
+		return 5 * time.Second //nolint: mnd
289
+	}
290
+
291
+	return p.ThrottleCheckInterval
292
+}
293
+
272 294
 func (p ProxyOpts) getLogger(name string) Logger {
273 295
 	return p.Logger.Named(name)
274 296
 }

+ 157
- 0
mtglib/proxy_stats.go Parādīt failu

@@ -3,6 +3,7 @@ package mtglib
3 3
 import (
4 4
 	"context"
5 5
 	"encoding/json"
6
+	"fmt"
6 7
 	"net"
7 8
 	"net/http"
8 9
 	"sync"
@@ -23,6 +24,13 @@ type ProxyStats struct {
23 24
 	mu        sync.RWMutex
24 25
 	users     map[string]*secretStats
25 26
 	startedAt time.Time
27
+
28
+	// Throttle: per-user connection caps recomputed every throttleInterval.
29
+	throttleMu       sync.RWMutex
30
+	throttleCaps     map[string]int64
31
+	throttleLimit    int64
32
+	throttleInterval time.Duration
33
+	throttleActive   atomic.Bool
26 34
 }
27 35
 
28 36
 // NewProxyStats creates a new ProxyStats instance.
@@ -87,14 +95,140 @@ func (s *ProxyStats) UpdateLastSeen(name string) {
87 95
 	s.getOrCreate(name).lastSeen.Store(time.Now())
88 96
 }
89 97
 
98
+// SetThrottle configures connection throttling. Must be called before
99
+// startThrottleLoop and before any connections arrive.
100
+func (s *ProxyStats) SetThrottle(limit int64, interval time.Duration) {
101
+	s.throttleLimit = limit
102
+	s.throttleInterval = interval
103
+	s.throttleCaps = make(map[string]int64)
104
+}
105
+
106
+// CanConnect returns true if the user is allowed to open a new connection
107
+// under the current throttle caps. If throttling is not configured or the
108
+// user has no cap, it always returns true.
109
+func (s *ProxyStats) CanConnect(name string) bool {
110
+	if s.throttleLimit == 0 {
111
+		return true
112
+	}
113
+
114
+	s.throttleMu.RLock()
115
+	cap, hasCap := s.throttleCaps[name]
116
+	s.throttleMu.RUnlock()
117
+
118
+	if !hasCap {
119
+		return true
120
+	}
121
+
122
+	return s.getOrCreate(name).connections.Load() < cap
123
+}
124
+
125
+// startThrottleLoop runs a background goroutine that recomputes per-user
126
+// caps every throttleInterval.
127
+func (s *ProxyStats) startThrottleLoop(ctx context.Context, logger Logger) {
128
+	go func() {
129
+		ticker := time.NewTicker(s.throttleInterval)
130
+		defer ticker.Stop()
131
+
132
+		for {
133
+			select {
134
+			case <-ctx.Done():
135
+				return
136
+			case <-ticker.C:
137
+				s.recomputeCaps(logger)
138
+			}
139
+		}
140
+	}()
141
+
142
+	logger.BindStr("limit", fmt.Sprintf("%d", s.throttleLimit)).
143
+		BindStr("interval", s.throttleInterval.String()).
144
+		Info("throttle loop started")
145
+}
146
+
147
+func (s *ProxyStats) recomputeCaps(logger Logger) {
148
+	s.mu.RLock()
149
+	userConns := make(map[string]int64, len(s.users))
150
+	for name, st := range s.users {
151
+		userConns[name] = st.connections.Load()
152
+	}
153
+	s.mu.RUnlock()
154
+
155
+	caps := computeFairCaps(userConns, s.throttleLimit)
156
+	wasActive := s.throttleActive.Load()
157
+	nowActive := len(caps) > 0
158
+
159
+	s.throttleMu.Lock()
160
+	s.throttleCaps = caps
161
+	s.throttleActive.Store(nowActive)
162
+	s.throttleMu.Unlock()
163
+
164
+	if nowActive && !wasActive {
165
+		logger.Warning("throttle activated")
166
+	} else if !nowActive && wasActive {
167
+		logger.Info("throttle deactivated")
168
+	}
169
+}
170
+
171
+// computeFairCaps implements the fair-share algorithm. Users below the equal
172
+// share keep their connections; remaining budget is split equally among the
173
+// rest. Returns nil when no throttling is needed.
174
+func computeFairCaps(userConns map[string]int64, limit int64) map[string]int64 {
175
+	var total int64
176
+	for _, c := range userConns {
177
+		total += c
178
+	}
179
+
180
+	if total <= limit {
181
+		return nil
182
+	}
183
+
184
+	remaining := make(map[string]int64, len(userConns))
185
+	for k, v := range userConns {
186
+		remaining[k] = v
187
+	}
188
+
189
+	budget := limit
190
+	caps := make(map[string]int64)
191
+
192
+	for len(remaining) > 0 {
193
+		fairShare := budget / int64(len(remaining))
194
+		changed := false
195
+
196
+		for name, conns := range remaining {
197
+			if conns <= fairShare {
198
+				budget -= conns
199
+				delete(remaining, name)
200
+				changed = true
201
+			}
202
+		}
203
+
204
+		if !changed {
205
+			for name := range remaining {
206
+				caps[name] = fairShare
207
+			}
208
+
209
+			break
210
+		}
211
+	}
212
+
213
+	return caps
214
+}
215
+
90 216
 // StatsResponse is the JSON response for the stats endpoint.
91 217
 type StatsResponse struct {
92 218
 	StartedAt        time.Time                `json:"started_at"`
93 219
 	UptimeSeconds    int64                    `json:"uptime_seconds"`
94 220
 	TotalConnections int64                    `json:"total_connections"`
221
+	Throttle         *ThrottleJSON            `json:"throttle,omitempty"`
95 222
 	Users            map[string]UserStatsJSON `json:"users"`
96 223
 }
97 224
 
225
+// ThrottleJSON is the throttle portion of the stats JSON response.
226
+type ThrottleJSON struct {
227
+	Active bool             `json:"active"`
228
+	Limit  int64            `json:"limit"`
229
+	Caps   map[string]int64 `json:"caps,omitempty"`
230
+}
231
+
98 232
 // UserStatsJSON is the per-user portion of the stats JSON response.
99 233
 type UserStatsJSON struct {
100 234
 	Connections int64      `json:"connections"`
@@ -129,10 +263,33 @@ func (s *ProxyStats) ServeHTTP(w http.ResponseWriter, r *http.Request) {
129 263
 		}
130 264
 	}
131 265
 
266
+	var throttle *ThrottleJSON
267
+	if s.throttleLimit > 0 {
268
+		s.throttleMu.RLock()
269
+		active := s.throttleActive.Load()
270
+
271
+		var capsCopy map[string]int64
272
+		if len(s.throttleCaps) > 0 {
273
+			capsCopy = make(map[string]int64, len(s.throttleCaps))
274
+			for k, v := range s.throttleCaps {
275
+				capsCopy[k] = v
276
+			}
277
+		}
278
+
279
+		s.throttleMu.RUnlock()
280
+
281
+		throttle = &ThrottleJSON{
282
+			Active: active,
283
+			Limit:  s.throttleLimit,
284
+			Caps:   capsCopy,
285
+		}
286
+	}
287
+
132 288
 	resp := StatsResponse{
133 289
 		StartedAt:        s.startedAt,
134 290
 		UptimeSeconds:    int64(time.Since(s.startedAt).Seconds()),
135 291
 		TotalConnections: totalConns,
292
+		Throttle:         throttle,
136 293
 		Users:            users,
137 294
 	}
138 295
 

+ 158
- 0
mtglib/proxy_stats_test.go Parādīt failu

@@ -159,6 +159,164 @@ func TestServeHTTPEmpty(t *testing.T) {
159 159
 	assert.Equal(t, int64(0), resp.TotalConnections)
160 160
 }
161 161
 
162
+func TestComputeFairCaps_NoThrottle(t *testing.T) {
163
+	t.Parallel()
164
+
165
+	caps := computeFairCaps(map[string]int64{
166
+		"a": 10,
167
+		"b": 20,
168
+	}, 100)
169
+
170
+	assert.Nil(t, caps)
171
+}
172
+
173
+func TestComputeFairCaps_ExactLimit(t *testing.T) {
174
+	t.Parallel()
175
+
176
+	caps := computeFairCaps(map[string]int64{
177
+		"a": 50,
178
+		"b": 50,
179
+	}, 100)
180
+
181
+	assert.Nil(t, caps)
182
+}
183
+
184
+func TestComputeFairCaps_UserExample(t *testing.T) {
185
+	t.Parallel()
186
+
187
+	// The user's exact example: limit=100, users=[1, 1, 90, 110]
188
+	// Small users keep their 1+1=2, remaining budget=98, split among 2 → 49 each
189
+	caps := computeFairCaps(map[string]int64{
190
+		"a": 1,
191
+		"b": 1,
192
+		"c": 90,
193
+		"d": 110,
194
+	}, 100)
195
+
196
+	assert.Len(t, caps, 2)
197
+	assert.Equal(t, int64(49), caps["c"])
198
+	assert.Equal(t, int64(49), caps["d"])
199
+	// "a" and "b" should not appear in caps (they're under the fair share)
200
+	_, hasA := caps["a"]
201
+	_, hasB := caps["b"]
202
+	assert.False(t, hasA)
203
+	assert.False(t, hasB)
204
+}
205
+
206
+func TestComputeFairCaps_AllOverLimit(t *testing.T) {
207
+	t.Parallel()
208
+
209
+	caps := computeFairCaps(map[string]int64{
210
+		"a": 100,
211
+		"b": 100,
212
+	}, 50)
213
+
214
+	assert.Len(t, caps, 2)
215
+	assert.Equal(t, int64(25), caps["a"])
216
+	assert.Equal(t, int64(25), caps["b"])
217
+}
218
+
219
+func TestComputeFairCaps_SingleHeavyUser(t *testing.T) {
220
+	t.Parallel()
221
+
222
+	caps := computeFairCaps(map[string]int64{
223
+		"light": 5,
224
+		"heavy": 200,
225
+	}, 100)
226
+
227
+	// light(5) < fairShare(50), keeps 5. Budget = 95. Heavy capped to 95.
228
+	assert.Len(t, caps, 1)
229
+	assert.Equal(t, int64(95), caps["heavy"])
230
+}
231
+
232
+func TestCanConnect_NoThrottle(t *testing.T) {
233
+	t.Parallel()
234
+
235
+	stats := NewProxyStats()
236
+	// throttleLimit = 0 (default), so CanConnect always returns true
237
+	assert.True(t, stats.CanConnect("anyone"))
238
+}
239
+
240
+func TestCanConnect_WithCap(t *testing.T) {
241
+	t.Parallel()
242
+
243
+	stats := NewProxyStats()
244
+	stats.PreRegister("heavy")
245
+	stats.SetThrottle(100, 5*time.Second)
246
+
247
+	// Simulate 50 connections
248
+	for range 50 {
249
+		stats.OnConnect("heavy")
250
+	}
251
+
252
+	// Set cap to 50
253
+	stats.throttleMu.Lock()
254
+	stats.throttleCaps = map[string]int64{"heavy": 50}
255
+	stats.throttleActive.Store(true)
256
+	stats.throttleMu.Unlock()
257
+
258
+	// At exactly the cap → reject
259
+	assert.False(t, stats.CanConnect("heavy"))
260
+
261
+	// Disconnect one → allow
262
+	stats.OnDisconnect("heavy")
263
+	assert.True(t, stats.CanConnect("heavy"))
264
+}
265
+
266
+func TestCanConnect_NoCap(t *testing.T) {
267
+	t.Parallel()
268
+
269
+	stats := NewProxyStats()
270
+	stats.SetThrottle(100, 5*time.Second)
271
+
272
+	// User not in caps map → always allowed
273
+	assert.True(t, stats.CanConnect("uncapped-user"))
274
+}
275
+
276
+func TestServeHTTPThrottleInfo(t *testing.T) {
277
+	t.Parallel()
278
+
279
+	stats := NewProxyStats()
280
+	stats.PreRegister("alice")
281
+	stats.SetThrottle(100, 5*time.Second)
282
+
283
+	stats.throttleMu.Lock()
284
+	stats.throttleCaps = map[string]int64{"alice": 50}
285
+	stats.throttleActive.Store(true)
286
+	stats.throttleMu.Unlock()
287
+
288
+	rec := httptest.NewRecorder()
289
+	req := httptest.NewRequest(http.MethodGet, "/stats", nil)
290
+
291
+	stats.ServeHTTP(rec, req)
292
+
293
+	var resp StatsResponse
294
+	err := json.Unmarshal(rec.Body.Bytes(), &resp)
295
+	require.NoError(t, err)
296
+
297
+	require.NotNil(t, resp.Throttle)
298
+	assert.True(t, resp.Throttle.Active)
299
+	assert.Equal(t, int64(100), resp.Throttle.Limit)
300
+	assert.Equal(t, int64(50), resp.Throttle.Caps["alice"])
301
+}
302
+
303
+func TestServeHTTPNoThrottle(t *testing.T) {
304
+	t.Parallel()
305
+
306
+	stats := NewProxyStats()
307
+
308
+	rec := httptest.NewRecorder()
309
+	req := httptest.NewRequest(http.MethodGet, "/stats", nil)
310
+
311
+	stats.ServeHTTP(rec, req)
312
+
313
+	var resp StatsResponse
314
+	err := json.Unmarshal(rec.Body.Bytes(), &resp)
315
+	require.NoError(t, err)
316
+
317
+	assert.Nil(t, resp.Throttle)
318
+}
319
+
162 320
 func TestServeHTTPLastSeenZeroIsNull(t *testing.T) {
163 321
 	t.Parallel()
164 322
 

Notiek ielāde…
Atcelt
Saglabāt