Bladeren bron

Add per-user connection throttling with fair-share algorithm

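A background goroutine (every 5s by default) checks the total number
of connections. When the total exceeds a configurable limit, it
computes per-user caps using a fair-share algorithm: users at or below
the fair share keep their connections and get no cap, and the remaining
budget is split equally among the heavy users. New connections from
users who have reached their cap are rejected; existing connections
are not killed. When the total drops back to the limit or below, all
caps are cleared on the next check.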

Config (throttling is disabled when max-connections is 0 or unset):
  [throttle]
  max-connections = 5000
  check-interval = "5s"

When throttling is configured, the stats API response now includes the
throttle state: whether it is active, the limit, and the current per-user caps.
pull/434/head
Alexey Dolotov 1 maand geleden
bovenliggende
commit
5dc9ba353a
8 gewijzigde bestanden met toevoegingen van 381 en 2 verwijderingen
  1. 3
    0
      internal/cli/run_proxy.go
  2. 5
    1
      internal/config/config.go
  3. 5
    1
      internal/config/parse.go
  4. 19
    0
      mtglib/events.go
  5. 12
    0
      mtglib/proxy.go
  6. 22
    0
      mtglib/proxy_opts.go
  7. 157
    0
      mtglib/proxy_stats.go
  8. 158
    0
      mtglib/proxy_stats_test.go

+ 3
- 0
internal/cli/run_proxy.go Bestand weergeven

271
 		DoppelGangerDRS:     conf.Defense.Doppelganger.DRS.Get(false),
271
 		DoppelGangerDRS:     conf.Defense.Doppelganger.DRS.Get(false),
272
 
272
 
273
 		APIBindTo: conf.APIBindTo.Get(""),
273
 		APIBindTo: conf.APIBindTo.Get(""),
274
+
275
+		ThrottleMaxConnections: conf.Throttle.MaxConnections.Get(0),
276
+		ThrottleCheckInterval:  conf.Throttle.CheckInterval.Get(5 * time.Second),
274
 	}
277
 	}
275
 
278
 
276
 	proxy, err := mtglib.NewProxy(opts)
279
 	proxy, err := mtglib.NewProxy(opts)

+ 5
- 1
internal/config/config.go Bestand weergeven

70
 		Proxies []TypeProxyURL `json:"proxies"`
70
 		Proxies []TypeProxyURL `json:"proxies"`
71
 	} `json:"network"`
71
 	} `json:"network"`
72
 	APIBindTo TypeHostPort `json:"apiBindTo"`
72
 	APIBindTo TypeHostPort `json:"apiBindTo"`
73
-	Stats     struct {
73
+	Throttle  struct {
74
+		MaxConnections TypeConcurrency `json:"maxConnections"`
75
+		CheckInterval  TypeDuration    `json:"checkInterval"`
76
+	} `json:"throttle"`
77
+	Stats struct {
74
 		StatsD struct {
78
 		StatsD struct {
75
 			Optional
79
 			Optional
76
 
80
 

+ 5
- 1
internal/config/parse.go Bestand weergeven

65
 		Proxies []string `toml:"proxies" json:"proxies,omitempty"`
65
 		Proxies []string `toml:"proxies" json:"proxies,omitempty"`
66
 	} `toml:"network" json:"network,omitempty"`
66
 	} `toml:"network" json:"network,omitempty"`
67
 	APIBindTo string `toml:"api-bind-to" json:"apiBindTo,omitempty"`
67
 	APIBindTo string `toml:"api-bind-to" json:"apiBindTo,omitempty"`
68
-	Stats     struct {
68
+	Throttle  struct {
69
+		MaxConnections uint   `toml:"max-connections" json:"maxConnections,omitempty"`
70
+		CheckInterval  string `toml:"check-interval" json:"checkInterval,omitempty"`
71
+	} `toml:"throttle" json:"throttle,omitempty"`
72
+	Stats struct {
69
 		StatsD struct {
73
 		StatsD struct {
70
 			Enabled      bool   `toml:"enabled" json:"enabled,omitempty"`
74
 			Enabled      bool   `toml:"enabled" json:"enabled,omitempty"`
71
 			Address      string `toml:"address" json:"address,omitempty"`
75
 			Address      string `toml:"address" json:"address,omitempty"`

+ 19
- 0
mtglib/events.go Bestand weergeven

93
 	eventBase
93
 	eventBase
94
 }
94
 }
95
 
95
 
96
+// EventThrottled is emitted when a connection is rejected because the
97
+// per-user connection cap has been reached.
98
+type EventThrottled struct {
99
+	eventBase
100
+
101
+	SecretName string
102
+}
103
+
96
 // EventIPListSize is emitted when mtg updates a contents of the ip lists:
104
 // EventIPListSize is emitted when mtg updates a contents of the ip lists:
97
 // allowlist or blocklist.
105
 // allowlist or blocklist.
98
 type EventIPListSize struct {
106
 type EventIPListSize struct {
200
 	}
208
 	}
201
 }
209
 }
202
 
210
 
211
+// NewEventThrottled creates a new EventThrottled event.
212
+func NewEventThrottled(streamID, secretName string) EventThrottled {
213
+	return EventThrottled{
214
+		eventBase: eventBase{
215
+			timestamp: time.Now(),
216
+			streamID:  streamID,
217
+		},
218
+		SecretName: secretName,
219
+	}
220
+}
221
+
203
 // NewEventIPListSize creates a new EventIPListSize event.
222
 // NewEventIPListSize creates a new EventIPListSize event.
204
 func NewEventIPListSize(size int, isBlockList bool) EventIPListSize {
223
 func NewEventIPListSize(size int, isBlockList bool) EventIPListSize {
205
 	return EventIPListSize{
224
 	return EventIPListSize{

+ 12
- 0
mtglib/proxy.go Bestand weergeven

87
 		return
87
 		return
88
 	}
88
 	}
89
 
89
 
90
+	if !p.stats.CanConnect(ctx.secretName) {
91
+		ctx.logger.Info("connection throttled")
92
+		p.eventStream.Send(ctx, NewEventThrottled(ctx.streamID, ctx.secretName))
93
+
94
+		return
95
+	}
96
+
90
 	p.stats.OnConnect(ctx.secretName)
97
 	p.stats.OnConnect(ctx.secretName)
91
 	p.stats.UpdateLastSeen(ctx.secretName)
98
 	p.stats.UpdateLastSeen(ctx.secretName)
92
 
99
 
381
 		stats.StartServer(ctx, opts.APIBindTo, logger)
388
 		stats.StartServer(ctx, opts.APIBindTo, logger)
382
 	}
389
 	}
383
 
390
 
391
+	if opts.ThrottleMaxConnections > 0 {
392
+		stats.SetThrottle(int64(opts.ThrottleMaxConnections), opts.getThrottleCheckInterval())
393
+		stats.startThrottleLoop(ctx, logger)
394
+	}
395
+
384
 	proxy := &Proxy{
396
 	proxy := &Proxy{
385
 		ctx:                      ctx,
397
 		ctx:                      ctx,
386
 		ctxCancel:                cancel,
398
 		ctxCancel:                cancel,

+ 22
- 0
mtglib/proxy_opts.go Bestand weergeven

175
 	//
175
 	//
176
 	// This is an optional setting.
176
 	// This is an optional setting.
177
 	APIBindTo string
177
 	APIBindTo string
178
+
179
+	// ThrottleMaxConnections is the total connection limit. When total
180
+	// connections exceed this value, per-user caps are computed using
181
+	// a fair-share algorithm and new connections from over-cap users
182
+	// are rejected. 0 disables throttling.
183
+	//
184
+	// This is an optional setting.
185
+	ThrottleMaxConnections uint
186
+
187
+	// ThrottleCheckInterval is how often the throttle recomputes per-user
188
+	// caps. Defaults to 5 seconds.
189
+	//
190
+	// This is an optional setting.
191
+	ThrottleCheckInterval time.Duration
178
 }
192
 }
179
 
193
 
180
 func (p ProxyOpts) valid() error {
194
 func (p ProxyOpts) valid() error {
269
 	return p.IdleTimeout
283
 	return p.IdleTimeout
270
 }
284
 }
271
 
285
 
286
+func (p ProxyOpts) getThrottleCheckInterval() time.Duration {
287
+	if p.ThrottleCheckInterval == 0 {
288
+		return 5 * time.Second //nolint: mnd
289
+	}
290
+
291
+	return p.ThrottleCheckInterval
292
+}
293
+
272
 func (p ProxyOpts) getLogger(name string) Logger {
294
 func (p ProxyOpts) getLogger(name string) Logger {
273
 	return p.Logger.Named(name)
295
 	return p.Logger.Named(name)
274
 }
296
 }

+ 157
- 0
mtglib/proxy_stats.go Bestand weergeven

3
 import (
3
 import (
4
 	"context"
4
 	"context"
5
 	"encoding/json"
5
 	"encoding/json"
6
+	"fmt"
6
 	"net"
7
 	"net"
7
 	"net/http"
8
 	"net/http"
8
 	"sync"
9
 	"sync"
23
 	mu        sync.RWMutex
24
 	mu        sync.RWMutex
24
 	users     map[string]*secretStats
25
 	users     map[string]*secretStats
25
 	startedAt time.Time
26
 	startedAt time.Time
27
+
28
+	// Throttle: per-user connection caps recomputed every throttleInterval.
29
+	throttleMu       sync.RWMutex
30
+	throttleCaps     map[string]int64
31
+	throttleLimit    int64
32
+	throttleInterval time.Duration
33
+	throttleActive   atomic.Bool
26
 }
34
 }
27
 
35
 
28
 // NewProxyStats creates a new ProxyStats instance.
36
 // NewProxyStats creates a new ProxyStats instance.
87
 	s.getOrCreate(name).lastSeen.Store(time.Now())
95
 	s.getOrCreate(name).lastSeen.Store(time.Now())
88
 }
96
 }
89
 
97
 
98
+// SetThrottle configures connection throttling. Must be called before
99
+// startThrottleLoop and before any connections arrive.
100
+func (s *ProxyStats) SetThrottle(limit int64, interval time.Duration) {
101
+	s.throttleLimit = limit
102
+	s.throttleInterval = interval
103
+	s.throttleCaps = make(map[string]int64)
104
+}
105
+
106
+// CanConnect returns true if the user is allowed to open a new connection
107
+// under the current throttle caps. If throttling is not configured or the
108
+// user has no cap, it always returns true.
109
+func (s *ProxyStats) CanConnect(name string) bool {
110
+	if s.throttleLimit == 0 {
111
+		return true
112
+	}
113
+
114
+	s.throttleMu.RLock()
115
+	cap, hasCap := s.throttleCaps[name]
116
+	s.throttleMu.RUnlock()
117
+
118
+	if !hasCap {
119
+		return true
120
+	}
121
+
122
+	return s.getOrCreate(name).connections.Load() < cap
123
+}
124
+
125
+// startThrottleLoop runs a background goroutine that recomputes per-user
126
+// caps every throttleInterval.
127
+func (s *ProxyStats) startThrottleLoop(ctx context.Context, logger Logger) {
128
+	go func() {
129
+		ticker := time.NewTicker(s.throttleInterval)
130
+		defer ticker.Stop()
131
+
132
+		for {
133
+			select {
134
+			case <-ctx.Done():
135
+				return
136
+			case <-ticker.C:
137
+				s.recomputeCaps(logger)
138
+			}
139
+		}
140
+	}()
141
+
142
+	logger.BindStr("limit", fmt.Sprintf("%d", s.throttleLimit)).
143
+		BindStr("interval", s.throttleInterval.String()).
144
+		Info("throttle loop started")
145
+}
146
+
147
+func (s *ProxyStats) recomputeCaps(logger Logger) {
148
+	s.mu.RLock()
149
+	userConns := make(map[string]int64, len(s.users))
150
+	for name, st := range s.users {
151
+		userConns[name] = st.connections.Load()
152
+	}
153
+	s.mu.RUnlock()
154
+
155
+	caps := computeFairCaps(userConns, s.throttleLimit)
156
+	wasActive := s.throttleActive.Load()
157
+	nowActive := len(caps) > 0
158
+
159
+	s.throttleMu.Lock()
160
+	s.throttleCaps = caps
161
+	s.throttleActive.Store(nowActive)
162
+	s.throttleMu.Unlock()
163
+
164
+	if nowActive && !wasActive {
165
+		logger.Warning("throttle activated")
166
+	} else if !nowActive && wasActive {
167
+		logger.Info("throttle deactivated")
168
+	}
169
+}
170
+
171
+// computeFairCaps implements the fair-share algorithm. Users below the equal
172
+// share keep their connections; remaining budget is split equally among the
173
+// rest. Returns nil when no throttling is needed.
174
+func computeFairCaps(userConns map[string]int64, limit int64) map[string]int64 {
175
+	var total int64
176
+	for _, c := range userConns {
177
+		total += c
178
+	}
179
+
180
+	if total <= limit {
181
+		return nil
182
+	}
183
+
184
+	remaining := make(map[string]int64, len(userConns))
185
+	for k, v := range userConns {
186
+		remaining[k] = v
187
+	}
188
+
189
+	budget := limit
190
+	caps := make(map[string]int64)
191
+
192
+	for len(remaining) > 0 {
193
+		fairShare := budget / int64(len(remaining))
194
+		changed := false
195
+
196
+		for name, conns := range remaining {
197
+			if conns <= fairShare {
198
+				budget -= conns
199
+				delete(remaining, name)
200
+				changed = true
201
+			}
202
+		}
203
+
204
+		if !changed {
205
+			for name := range remaining {
206
+				caps[name] = fairShare
207
+			}
208
+
209
+			break
210
+		}
211
+	}
212
+
213
+	return caps
214
+}
215
+
90
 // StatsResponse is the JSON response for the stats endpoint.
216
 // StatsResponse is the JSON response for the stats endpoint.
91
 type StatsResponse struct {
217
 type StatsResponse struct {
92
 	StartedAt        time.Time                `json:"started_at"`
218
 	StartedAt        time.Time                `json:"started_at"`
93
 	UptimeSeconds    int64                    `json:"uptime_seconds"`
219
 	UptimeSeconds    int64                    `json:"uptime_seconds"`
94
 	TotalConnections int64                    `json:"total_connections"`
220
 	TotalConnections int64                    `json:"total_connections"`
221
+	Throttle         *ThrottleJSON            `json:"throttle,omitempty"`
95
 	Users            map[string]UserStatsJSON `json:"users"`
222
 	Users            map[string]UserStatsJSON `json:"users"`
96
 }
223
 }
97
 
224
 
225
+// ThrottleJSON is the throttle portion of the stats JSON response.
226
+type ThrottleJSON struct {
227
+	Active bool             `json:"active"`
228
+	Limit  int64            `json:"limit"`
229
+	Caps   map[string]int64 `json:"caps,omitempty"`
230
+}
231
+
98
 // UserStatsJSON is the per-user portion of the stats JSON response.
232
 // UserStatsJSON is the per-user portion of the stats JSON response.
99
 type UserStatsJSON struct {
233
 type UserStatsJSON struct {
100
 	Connections int64      `json:"connections"`
234
 	Connections int64      `json:"connections"`
129
 		}
263
 		}
130
 	}
264
 	}
131
 
265
 
266
+	var throttle *ThrottleJSON
267
+	if s.throttleLimit > 0 {
268
+		s.throttleMu.RLock()
269
+		active := s.throttleActive.Load()
270
+
271
+		var capsCopy map[string]int64
272
+		if len(s.throttleCaps) > 0 {
273
+			capsCopy = make(map[string]int64, len(s.throttleCaps))
274
+			for k, v := range s.throttleCaps {
275
+				capsCopy[k] = v
276
+			}
277
+		}
278
+
279
+		s.throttleMu.RUnlock()
280
+
281
+		throttle = &ThrottleJSON{
282
+			Active: active,
283
+			Limit:  s.throttleLimit,
284
+			Caps:   capsCopy,
285
+		}
286
+	}
287
+
132
 	resp := StatsResponse{
288
 	resp := StatsResponse{
133
 		StartedAt:        s.startedAt,
289
 		StartedAt:        s.startedAt,
134
 		UptimeSeconds:    int64(time.Since(s.startedAt).Seconds()),
290
 		UptimeSeconds:    int64(time.Since(s.startedAt).Seconds()),
135
 		TotalConnections: totalConns,
291
 		TotalConnections: totalConns,
292
+		Throttle:         throttle,
136
 		Users:            users,
293
 		Users:            users,
137
 	}
294
 	}
138
 
295
 

+ 158
- 0
mtglib/proxy_stats_test.go Bestand weergeven

159
 	assert.Equal(t, int64(0), resp.TotalConnections)
159
 	assert.Equal(t, int64(0), resp.TotalConnections)
160
 }
160
 }
161
 
161
 
162
+func TestComputeFairCaps_NoThrottle(t *testing.T) {
163
+	t.Parallel()
164
+
165
+	caps := computeFairCaps(map[string]int64{
166
+		"a": 10,
167
+		"b": 20,
168
+	}, 100)
169
+
170
+	assert.Nil(t, caps)
171
+}
172
+
173
+func TestComputeFairCaps_ExactLimit(t *testing.T) {
174
+	t.Parallel()
175
+
176
+	caps := computeFairCaps(map[string]int64{
177
+		"a": 50,
178
+		"b": 50,
179
+	}, 100)
180
+
181
+	assert.Nil(t, caps)
182
+}
183
+
184
+func TestComputeFairCaps_UserExample(t *testing.T) {
185
+	t.Parallel()
186
+
187
+	// The user's exact example: limit=100, users=[1, 1, 90, 110]
188
+	// Small users keep their 1+1=2, remaining budget=98, split among 2 → 49 each
189
+	caps := computeFairCaps(map[string]int64{
190
+		"a": 1,
191
+		"b": 1,
192
+		"c": 90,
193
+		"d": 110,
194
+	}, 100)
195
+
196
+	assert.Len(t, caps, 2)
197
+	assert.Equal(t, int64(49), caps["c"])
198
+	assert.Equal(t, int64(49), caps["d"])
199
+	// "a" and "b" should not appear in caps (they're under the fair share)
200
+	_, hasA := caps["a"]
201
+	_, hasB := caps["b"]
202
+	assert.False(t, hasA)
203
+	assert.False(t, hasB)
204
+}
205
+
206
+func TestComputeFairCaps_AllOverLimit(t *testing.T) {
207
+	t.Parallel()
208
+
209
+	caps := computeFairCaps(map[string]int64{
210
+		"a": 100,
211
+		"b": 100,
212
+	}, 50)
213
+
214
+	assert.Len(t, caps, 2)
215
+	assert.Equal(t, int64(25), caps["a"])
216
+	assert.Equal(t, int64(25), caps["b"])
217
+}
218
+
219
+func TestComputeFairCaps_SingleHeavyUser(t *testing.T) {
220
+	t.Parallel()
221
+
222
+	caps := computeFairCaps(map[string]int64{
223
+		"light": 5,
224
+		"heavy": 200,
225
+	}, 100)
226
+
227
+	// light(5) < fairShare(50), keeps 5. Budget = 95. Heavy capped to 95.
228
+	assert.Len(t, caps, 1)
229
+	assert.Equal(t, int64(95), caps["heavy"])
230
+}
231
+
232
+func TestCanConnect_NoThrottle(t *testing.T) {
233
+	t.Parallel()
234
+
235
+	stats := NewProxyStats()
236
+	// throttleLimit = 0 (default), so CanConnect always returns true
237
+	assert.True(t, stats.CanConnect("anyone"))
238
+}
239
+
240
+func TestCanConnect_WithCap(t *testing.T) {
241
+	t.Parallel()
242
+
243
+	stats := NewProxyStats()
244
+	stats.PreRegister("heavy")
245
+	stats.SetThrottle(100, 5*time.Second)
246
+
247
+	// Simulate 50 connections
248
+	for range 50 {
249
+		stats.OnConnect("heavy")
250
+	}
251
+
252
+	// Set cap to 50
253
+	stats.throttleMu.Lock()
254
+	stats.throttleCaps = map[string]int64{"heavy": 50}
255
+	stats.throttleActive.Store(true)
256
+	stats.throttleMu.Unlock()
257
+
258
+	// At exactly the cap → reject
259
+	assert.False(t, stats.CanConnect("heavy"))
260
+
261
+	// Disconnect one → allow
262
+	stats.OnDisconnect("heavy")
263
+	assert.True(t, stats.CanConnect("heavy"))
264
+}
265
+
266
+func TestCanConnect_NoCap(t *testing.T) {
267
+	t.Parallel()
268
+
269
+	stats := NewProxyStats()
270
+	stats.SetThrottle(100, 5*time.Second)
271
+
272
+	// User not in caps map → always allowed
273
+	assert.True(t, stats.CanConnect("uncapped-user"))
274
+}
275
+
276
+func TestServeHTTPThrottleInfo(t *testing.T) {
277
+	t.Parallel()
278
+
279
+	stats := NewProxyStats()
280
+	stats.PreRegister("alice")
281
+	stats.SetThrottle(100, 5*time.Second)
282
+
283
+	stats.throttleMu.Lock()
284
+	stats.throttleCaps = map[string]int64{"alice": 50}
285
+	stats.throttleActive.Store(true)
286
+	stats.throttleMu.Unlock()
287
+
288
+	rec := httptest.NewRecorder()
289
+	req := httptest.NewRequest(http.MethodGet, "/stats", nil)
290
+
291
+	stats.ServeHTTP(rec, req)
292
+
293
+	var resp StatsResponse
294
+	err := json.Unmarshal(rec.Body.Bytes(), &resp)
295
+	require.NoError(t, err)
296
+
297
+	require.NotNil(t, resp.Throttle)
298
+	assert.True(t, resp.Throttle.Active)
299
+	assert.Equal(t, int64(100), resp.Throttle.Limit)
300
+	assert.Equal(t, int64(50), resp.Throttle.Caps["alice"])
301
+}
302
+
303
+func TestServeHTTPNoThrottle(t *testing.T) {
304
+	t.Parallel()
305
+
306
+	stats := NewProxyStats()
307
+
308
+	rec := httptest.NewRecorder()
309
+	req := httptest.NewRequest(http.MethodGet, "/stats", nil)
310
+
311
+	stats.ServeHTTP(rec, req)
312
+
313
+	var resp StatsResponse
314
+	err := json.Unmarshal(rec.Body.Bytes(), &resp)
315
+	require.NoError(t, err)
316
+
317
+	assert.Nil(t, resp.Throttle)
318
+}
319
+
162
 func TestServeHTTPLastSeenZeroIsNull(t *testing.T) {
320
 func TestServeHTTPLastSeenZeroIsNull(t *testing.T) {
163
 	t.Parallel()
321
 	t.Parallel()
164
 
322
 

Laden…
Annuleren
Opslaan