| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336 |
- package mtglib
-
- import (
- "context"
- "encoding/json"
- "fmt"
- "net"
- "net/http"
- "sync"
- "sync/atomic"
- "time"
- )
-
// secretStats holds the live counters kept for one secret. All fields are
// atomics, so they are read and updated without any additional lock.
type secretStats struct {
	// connections goes up in OnConnect and down in OnDisconnect.
	connections atomic.Int64

	// bytesIn and bytesOut accumulate the amounts passed to AddBytesIn and
	// AddBytesOut.
	bytesIn, bytesOut atomic.Int64

	// lastSeen stores a time.Time; getOrCreate seeds it with the zero time,
	// which the stats endpoint reports as "never seen".
	lastSeen atomic.Value
}
-
- // ProxyStats tracks per-secret connection stats with atomic counters.
- // Thread-safe for concurrent access from proxy goroutines.
- type ProxyStats struct {
- mu sync.RWMutex
- users map[string]*secretStats
- startedAt time.Time
-
- // Throttle: per-user connection caps recomputed every throttleInterval.
- throttleMu sync.RWMutex
- throttleCaps map[string]int64
- throttleLimit int64
- throttleInterval time.Duration
- throttleActive atomic.Bool
- }
-
- // NewProxyStats creates a new ProxyStats instance.
- func NewProxyStats() *ProxyStats {
- return &ProxyStats{
- users: make(map[string]*secretStats),
- startedAt: time.Now(),
- }
- }
-
- func (s *ProxyStats) getOrCreate(name string) *secretStats {
- s.mu.RLock()
- st, ok := s.users[name]
- s.mu.RUnlock()
-
- if ok {
- return st
- }
-
- s.mu.Lock()
- defer s.mu.Unlock()
-
- if st, ok = s.users[name]; ok {
- return st
- }
-
- st = &secretStats{}
- st.lastSeen.Store(time.Time{})
- s.users[name] = st
-
- return st
- }
-
- // PreRegister adds a secret name to the stats map so it appears in output
- // even if no connections have been made yet.
- func (s *ProxyStats) PreRegister(name string) {
- s.getOrCreate(name)
- }
-
- // OnConnect increments the active connection count for the given secret.
- func (s *ProxyStats) OnConnect(name string) {
- s.getOrCreate(name).connections.Add(1)
- }
-
- // OnDisconnect decrements the active connection count for the given secret.
- func (s *ProxyStats) OnDisconnect(name string) {
- s.getOrCreate(name).connections.Add(-1)
- }
-
- // AddBytesIn adds to the bytes-in counter for the given secret.
- func (s *ProxyStats) AddBytesIn(name string, n int64) {
- s.getOrCreate(name).bytesIn.Add(n)
- }
-
- // AddBytesOut adds to the bytes-out counter for the given secret.
- func (s *ProxyStats) AddBytesOut(name string, n int64) {
- s.getOrCreate(name).bytesOut.Add(n)
- }
-
- // UpdateLastSeen sets the last-seen timestamp for the given secret to now.
- func (s *ProxyStats) UpdateLastSeen(name string) {
- s.getOrCreate(name).lastSeen.Store(time.Now())
- }
-
- // SetThrottle configures connection throttling. Must be called before
- // startThrottleLoop and before any connections arrive.
- func (s *ProxyStats) SetThrottle(limit int64, interval time.Duration) {
- s.throttleLimit = limit
- s.throttleInterval = interval
- s.throttleCaps = make(map[string]int64)
- }
-
- // CanConnect returns true if the user is allowed to open a new connection
- // under the current throttle caps. If throttling is not configured or the
- // user has no cap, it always returns true.
- func (s *ProxyStats) CanConnect(name string) bool {
- if s.throttleLimit == 0 {
- return true
- }
-
- s.throttleMu.RLock()
- cap, hasCap := s.throttleCaps[name]
- s.throttleMu.RUnlock()
-
- if !hasCap {
- return true
- }
-
- return s.getOrCreate(name).connections.Load() < cap
- }
-
- // startThrottleLoop runs a background goroutine that recomputes per-user
- // caps every throttleInterval.
- func (s *ProxyStats) startThrottleLoop(ctx context.Context, logger Logger) {
- go func() {
- ticker := time.NewTicker(s.throttleInterval)
- defer ticker.Stop()
-
- for {
- select {
- case <-ctx.Done():
- return
- case <-ticker.C:
- s.recomputeCaps(logger)
- }
- }
- }()
-
- logger.BindStr("limit", fmt.Sprintf("%d", s.throttleLimit)).
- BindStr("interval", s.throttleInterval.String()).
- Info("throttle loop started")
- }
-
- func (s *ProxyStats) recomputeCaps(logger Logger) {
- s.mu.RLock()
- userConns := make(map[string]int64, len(s.users))
- for name, st := range s.users {
- userConns[name] = st.connections.Load()
- }
- s.mu.RUnlock()
-
- caps := computeFairCaps(userConns, s.throttleLimit)
- wasActive := s.throttleActive.Load()
- nowActive := len(caps) > 0
-
- s.throttleMu.Lock()
- s.throttleCaps = caps
- s.throttleActive.Store(nowActive)
- s.throttleMu.Unlock()
-
- if nowActive && !wasActive {
- logger.Warning("throttle activated")
- } else if !nowActive && wasActive {
- logger.Info("throttle deactivated")
- }
- }
-
// computeFairCaps distributes limit connections among users.
//
// When the sum of all connection counts does not exceed limit, nothing needs
// capping and nil is returned. Otherwise the function repeatedly computes an
// equal share of the remaining budget for the users still under
// consideration; every user at or below that share keeps its connections
// (its count is subtracted from the budget and it receives no cap). Once a
// pass removes nobody, each remaining user is capped at the current share.
func computeFairCaps(userConns map[string]int64, limit int64) map[string]int64 {
	var sum int64
	for _, n := range userConns {
		sum += n
	}

	if sum <= limit {
		return nil
	}

	pending := make(map[string]int64, len(userConns))
	for user, n := range userConns {
		pending[user] = n
	}

	result := make(map[string]int64)
	left := limit

	for len(pending) > 0 {
		// The share is fixed for the whole pass even though the budget
		// shrinks as users are settled.
		share := left / int64(len(pending))
		settled := 0

		for user, n := range pending {
			if n > share {
				continue
			}

			left -= n
			delete(pending, user)
			settled++
		}

		if settled > 0 {
			continue
		}

		// Nobody fits under the share any more: cap everyone who is left.
		for user := range pending {
			result[user] = share
		}

		break
	}

	return result
}
-
- // StatsResponse is the JSON response for the stats endpoint.
- type StatsResponse struct {
- StartedAt time.Time `json:"started_at"`
- UptimeSeconds int64 `json:"uptime_seconds"`
- TotalConnections int64 `json:"total_connections"`
- Throttle *ThrottleJSON `json:"throttle,omitempty"`
- Users map[string]UserStatsJSON `json:"users"`
- }
-
// ThrottleJSON describes the throttling state inside a StatsResponse.
type ThrottleJSON struct {
	// Active tells whether the last recompute produced any caps.
	Active bool `json:"active"`
	// Limit is the configured overall connection limit.
	Limit int64 `json:"limit"`
	// Caps lists the per-user caps currently in force; omitted when empty.
	Caps map[string]int64 `json:"caps,omitempty"`
}
-
// UserStatsJSON carries the counters of a single user inside a StatsResponse.
type UserStatsJSON struct {
	// Connections is the number of currently active connections.
	Connections int64 `json:"connections"`
	// BytesIn is the accumulated bytes-in counter.
	BytesIn int64 `json:"bytes_in"`
	// BytesOut is the accumulated bytes-out counter.
	BytesOut int64 `json:"bytes_out"`
	// LastSeen is nil (JSON null) when no last-seen time has been recorded.
	LastSeen *time.Time `json:"last_seen"`
}
-
- func (s *ProxyStats) ServeHTTP(w http.ResponseWriter, r *http.Request) {
- s.mu.RLock()
- defer s.mu.RUnlock()
-
- var totalConns int64
-
- users := make(map[string]UserStatsJSON, len(s.users))
-
- for name, st := range s.users {
- conns := st.connections.Load()
- totalConns += conns
-
- lastSeen := st.lastSeen.Load().(time.Time) //nolint: forcetypeassert
- var lastSeenPtr *time.Time
- if !lastSeen.IsZero() {
- lastSeenPtr = &lastSeen
- }
-
- users[name] = UserStatsJSON{
- Connections: conns,
- BytesIn: st.bytesIn.Load(),
- BytesOut: st.bytesOut.Load(),
- LastSeen: lastSeenPtr,
- }
- }
-
- var throttle *ThrottleJSON
- if s.throttleLimit > 0 {
- s.throttleMu.RLock()
- active := s.throttleActive.Load()
-
- var capsCopy map[string]int64
- if len(s.throttleCaps) > 0 {
- capsCopy = make(map[string]int64, len(s.throttleCaps))
- for k, v := range s.throttleCaps {
- capsCopy[k] = v
- }
- }
-
- s.throttleMu.RUnlock()
-
- throttle = &ThrottleJSON{
- Active: active,
- Limit: s.throttleLimit,
- Caps: capsCopy,
- }
- }
-
- resp := StatsResponse{
- StartedAt: s.startedAt,
- UptimeSeconds: int64(time.Since(s.startedAt).Seconds()),
- TotalConnections: totalConns,
- Throttle: throttle,
- Users: users,
- }
-
- w.Header().Set("Content-Type", "application/json")
-
- if err := json.NewEncoder(w).Encode(resp); err != nil {
- http.Error(w, err.Error(), http.StatusInternalServerError)
- }
- }
-
- // StartServer starts an HTTP server for the stats API in a background goroutine.
- // The server is shut down when ctx is cancelled.
- func (s *ProxyStats) StartServer(ctx context.Context, bindTo string, logger Logger) {
- mux := http.NewServeMux()
- mux.Handle("/stats", s)
-
- srv := &http.Server{
- Addr: bindTo,
- Handler: mux,
- }
-
- ln, err := net.Listen("tcp", bindTo)
- if err != nil {
- logger.WarningError("cannot start stats API listener", err)
- return
- }
-
- go func() {
- if err := srv.Serve(ln); err != nil && err != http.ErrServerClosed {
- logger.WarningError("stats API server error", err)
- }
- }()
-
- go func() {
- <-ctx.Done()
-
- shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) //nolint: mnd
- defer cancel()
-
- srv.Shutdown(shutdownCtx) //nolint: errcheck
- }()
-
- logger.BindStr("bind", bindTo).Info("Stats API server started")
- }
|