Browse Source

Merge pull request #424 from dolonet/fix/relay-idle-timeout-shared-tracker

fix: use shared idle tracker for relay connections
tags/v2.2.6^2^2
Sergei Arkhipov 1 month ago
parent
commit
b926a0590c
No account linked to committer's email address
3 changed files with 212 additions and 9 deletions
  1. 53
    5
      mtglib/conns.go
  2. 151
    0
      mtglib/conns_internal_test.go
  3. 8
    4
      mtglib/proxy.go

+ 53
- 5
mtglib/conns.go View File

@@ -6,6 +6,7 @@ import (
6 6
 	"fmt"
7 7
 	"io"
8 8
 	"net"
9
+	"sync/atomic"
9 10
 	"time"
10 11
 
11 12
 	"github.com/9seconds/mtg/v2/essentials"
@@ -97,20 +98,67 @@ func newConnProxyProtocol(source, target essentials.Conn) *connProxyProtocol {
97 98
 	}
98 99
 }
99 100
 
101
+// idleTracker is a shared idle tracker for a pair of relay connections.
102
+// Both directions update the same timestamp so that activity in one direction
103
+// prevents the other (idle) direction from timing out.
104
+type idleTracker struct {
105
+	lastActive atomic.Int64 // unix nanos
106
+	timeout    time.Duration
107
+}
108
+
109
+func newIdleTracker(timeout time.Duration) *idleTracker {
110
+	t := &idleTracker{timeout: timeout}
111
+	t.touch()
112
+
113
+	return t
114
+}
115
+
116
+func (t *idleTracker) touch() {
117
+	t.lastActive.Store(time.Now().UnixNano())
118
+}
119
+
120
+func (t *idleTracker) isIdle() bool {
121
+	last := time.Unix(0, t.lastActive.Load())
122
+
123
+	return time.Since(last) >= t.timeout
124
+}
125
+
100 126
 type connIdleTimeout struct {
101 127
 	essentials.Conn
102 128
 
103
-	timeout time.Duration
129
+	tracker *idleTracker
104 130
 }
105 131
 
106 132
 func (c connIdleTimeout) Read(b []byte) (int, error) {
107
-	c.SetReadDeadline(time.Now().Add(c.timeout)) //nolint: errcheck
133
+	for {
134
+		c.SetReadDeadline(time.Now().Add(c.tracker.timeout)) //nolint: errcheck
135
+
136
+		n, err := c.Conn.Read(b)
137
+		if n > 0 {
138
+			c.tracker.touch()
108 139
 
109
-	return c.Conn.Read(b) //nolint: wrapcheck
140
+			return n, err //nolint: wrapcheck
141
+		}
142
+
143
+		if err != nil {
144
+			if netErr, ok := err.(net.Error); ok && netErr.Timeout() && !c.tracker.isIdle() { //nolint: errorlint
145
+				continue
146
+			}
147
+
148
+			return 0, err //nolint: wrapcheck
149
+		}
150
+
151
+		return 0, nil
152
+	}
110 153
 }
111 154
 
112 155
 func (c connIdleTimeout) Write(b []byte) (int, error) {
113
-	c.SetWriteDeadline(time.Now().Add(c.timeout)) //nolint: errcheck
156
+	c.SetWriteDeadline(time.Now().Add(c.tracker.timeout)) //nolint: errcheck
114 157
 
115
-	return c.Conn.Write(b) //nolint: wrapcheck
158
+	n, err := c.Conn.Write(b)
159
+	if n > 0 {
160
+		c.tracker.touch()
161
+	}
162
+
163
+	return n, err //nolint: wrapcheck
116 164
 }

+ 151
- 0
mtglib/conns_internal_test.go View File

@@ -16,6 +16,12 @@ import (
16 16
 	"github.com/stretchr/testify/suite"
17 17
 )
18 18
 
19
+type netTimeoutError struct{}
20
+
21
+func (e netTimeoutError) Error() string   { return "i/o timeout" }
22
+func (e netTimeoutError) Timeout() bool   { return true }
23
+func (e netTimeoutError) Temporary() bool { return true }
24
+
19 25
 type ConnRewindBaseConn struct {
20 26
 	testlib.EssentialsConnMock
21 27
 
@@ -291,6 +297,141 @@ func (suite *ConnProxyProtocolTestSuite) TearDownTest() {
291 297
 	suite.targetConnMock.AssertExpectations(suite.T())
292 298
 }
293 299
 
300
+type IdleTrackerTestSuite struct {
301
+	suite.Suite
302
+}
303
+
304
+func (suite *IdleTrackerTestSuite) TestNewNotIdle() {
305
+	tracker := newIdleTracker(time.Second)
306
+	suite.False(tracker.isIdle())
307
+}
308
+
309
+func (suite *IdleTrackerTestSuite) TestIdleAfterTimeout() {
310
+	tracker := newIdleTracker(10 * time.Millisecond)
311
+	time.Sleep(20 * time.Millisecond)
312
+
313
+	suite.True(tracker.isIdle())
314
+}
315
+
316
+func (suite *IdleTrackerTestSuite) TestTouchResetsIdle() {
317
+	tracker := newIdleTracker(50 * time.Millisecond)
318
+	time.Sleep(30 * time.Millisecond)
319
+
320
+	tracker.touch()
321
+
322
+	suite.False(tracker.isIdle())
323
+}
324
+
325
+type ConnIdleTimeoutTestSuite struct {
326
+	suite.Suite
327
+
328
+	connMock *testlib.EssentialsConnMock
329
+	tracker  *idleTracker
330
+	conn     connIdleTimeout
331
+}
332
+
333
+func (suite *ConnIdleTimeoutTestSuite) SetupTest() {
334
+	suite.connMock = &testlib.EssentialsConnMock{}
335
+	suite.tracker = newIdleTracker(time.Second)
336
+	suite.conn = connIdleTimeout{
337
+		Conn:    suite.connMock,
338
+		tracker: suite.tracker,
339
+	}
340
+}
341
+
342
+func (suite *ConnIdleTimeoutTestSuite) TearDownTest() {
343
+	suite.connMock.AssertExpectations(suite.T())
344
+}
345
+
346
+func (suite *ConnIdleTimeoutTestSuite) TestReadOk() {
347
+	suite.connMock.On("SetReadDeadline", mock.Anything).Return(nil)
348
+	suite.connMock.On("Read", mock.Anything).Once().Return(5, nil)
349
+
350
+	n, err := suite.conn.Read(make([]byte, 10))
351
+	suite.NoError(err)
352
+	suite.Equal(5, n)
353
+}
354
+
355
+func (suite *ConnIdleTimeoutTestSuite) TestReadNonTimeoutErr() {
356
+	suite.connMock.On("SetReadDeadline", mock.Anything).Return(nil)
357
+	suite.connMock.On("Read", mock.Anything).Once().Return(0, io.EOF)
358
+
359
+	n, err := suite.conn.Read(make([]byte, 10))
360
+	suite.True(errors.Is(err, io.EOF))
361
+	suite.Equal(0, n)
362
+}
363
+
364
+func (suite *ConnIdleTimeoutTestSuite) TestReadTimeoutRetriesWhenNotIdle() {
365
+	suite.connMock.On("SetReadDeadline", mock.Anything).Return(nil)
366
+	suite.connMock.On("Read", mock.Anything).Once().Return(0, netTimeoutError{})
367
+	suite.connMock.On("Read", mock.Anything).Once().Return(5, nil)
368
+
369
+	n, err := suite.conn.Read(make([]byte, 10))
370
+	suite.NoError(err)
371
+	suite.Equal(5, n)
372
+}
373
+
374
+func (suite *ConnIdleTimeoutTestSuite) TestReadTimeoutClosesWhenIdle() {
375
+	suite.tracker = newIdleTracker(time.Millisecond)
376
+	suite.conn = connIdleTimeout{
377
+		Conn:    suite.connMock,
378
+		tracker: suite.tracker,
379
+	}
380
+
381
+	time.Sleep(5 * time.Millisecond)
382
+
383
+	suite.connMock.On("SetReadDeadline", mock.Anything).Return(nil)
384
+	suite.connMock.On("Read", mock.Anything).Once().Return(0, netTimeoutError{})
385
+
386
+	n, err := suite.conn.Read(make([]byte, 10))
387
+	suite.Equal(0, n)
388
+
389
+	netErr, ok := err.(net.Error) //nolint: errorlint
390
+	suite.True(ok)
391
+	suite.True(netErr.Timeout())
392
+}
393
+
394
+func (suite *ConnIdleTimeoutTestSuite) TestSharedTrackerPreventsFalseTimeout() {
395
+	connMock2 := &testlib.EssentialsConnMock{}
396
+	conn2 := connIdleTimeout{
397
+		Conn:    connMock2,
398
+		tracker: suite.tracker,
399
+	}
400
+
401
+	connMock2.On("SetWriteDeadline", mock.Anything).Return(nil)
402
+	connMock2.On("Write", mock.Anything).Once().Return(5, nil)
403
+
404
+	_, _ = conn2.Write(make([]byte, 5))
405
+
406
+	suite.connMock.On("SetReadDeadline", mock.Anything).Return(nil)
407
+	suite.connMock.On("Read", mock.Anything).Once().Return(0, netTimeoutError{})
408
+	suite.connMock.On("Read", mock.Anything).Once().Return(3, nil)
409
+
410
+	n, err := suite.conn.Read(make([]byte, 10))
411
+	suite.NoError(err)
412
+	suite.Equal(3, n)
413
+
414
+	connMock2.AssertExpectations(suite.T())
415
+}
416
+
417
+func (suite *ConnIdleTimeoutTestSuite) TestWriteOk() {
418
+	suite.connMock.On("SetWriteDeadline", mock.Anything).Return(nil)
419
+	suite.connMock.On("Write", mock.Anything).Once().Return(5, nil)
420
+
421
+	n, err := suite.conn.Write(make([]byte, 5))
422
+	suite.NoError(err)
423
+	suite.Equal(5, n)
424
+}
425
+
426
+func (suite *ConnIdleTimeoutTestSuite) TestWriteErr() {
427
+	suite.connMock.On("SetWriteDeadline", mock.Anything).Return(nil)
428
+	suite.connMock.On("Write", mock.Anything).Once().Return(0, io.EOF)
429
+
430
+	n, err := suite.conn.Write(make([]byte, 5))
431
+	suite.True(errors.Is(err, io.EOF))
432
+	suite.Equal(0, n)
433
+}
434
+
294 435
 func TestConnTraffic(t *testing.T) {
295 436
 	t.Parallel()
296 437
 	suite.Run(t, &ConnTrafficTestSuite{})
@@ -305,3 +446,13 @@ func TestConnProxyProtocol(t *testing.T) {
305 446
 	t.Parallel()
306 447
 	suite.Run(t, &ConnProxyProtocolTestSuite{})
307 448
 }
449
+
450
+func TestIdleTracker(t *testing.T) {
451
+	t.Parallel()
452
+	suite.Run(t, &IdleTrackerTestSuite{})
453
+}
454
+
455
+func TestConnIdleTimeout(t *testing.T) {
456
+	t.Parallel()
457
+	suite.Run(t, &ConnIdleTimeoutTestSuite{})
458
+}

+ 8
- 4
mtglib/proxy.go View File

@@ -102,11 +102,13 @@ func (p *Proxy) ServeConn(conn essentials.Conn) {
102 102
 		return
103 103
 	}
104 104
 
105
+	tracker := newIdleTracker(p.idleTimeout)
106
+
105 107
 	relay.Relay(
106 108
 		ctx,
107 109
 		ctx.logger.Named("relay"),
108
-		connIdleTimeout{Conn: ctx.telegramConn, timeout: p.idleTimeout},
109
-		connIdleTimeout{Conn: ctx.clientConn, timeout: p.idleTimeout},
110
+		connIdleTimeout{Conn: ctx.telegramConn, tracker: tracker},
111
+		connIdleTimeout{Conn: ctx.clientConn, tracker: tracker},
110 112
 	)
111 113
 }
112 114
 
@@ -305,11 +307,13 @@ func (p *Proxy) doDomainFronting(ctx *streamContext, conn *connRewind) {
305 307
 		stream:   p.eventStream,
306 308
 	}
307 309
 
310
+	tracker := newIdleTracker(p.idleTimeout)
311
+
308 312
 	relay.Relay(
309 313
 		ctx,
310 314
 		ctx.logger.Named("domain-fronting"),
311
-		connIdleTimeout{Conn: frontConn, timeout: p.idleTimeout},
312
-		connIdleTimeout{Conn: conn, timeout: p.idleTimeout},
315
+		connIdleTimeout{Conn: frontConn, tracker: tracker},
316
+		connIdleTimeout{Conn: conn, tracker: tracker},
313 317
 	)
314 318
 }
315 319
 

Loading…
Cancel
Save