|
|
@@ -6,6 +6,7 @@ import (
|
|
6
|
6
|
"fmt"
|
|
7
|
7
|
"io"
|
|
8
|
8
|
"net"
|
|
|
9
|
+ "sync/atomic"
|
|
9
|
10
|
"time"
|
|
10
|
11
|
|
|
11
|
12
|
"github.com/9seconds/mtg/v2/essentials"
|
|
|
@@ -97,20 +98,67 @@ func newConnProxyProtocol(source, target essentials.Conn) *connProxyProtocol {
|
|
97
|
98
|
}
|
|
98
|
99
|
}
|
|
99
|
100
|
|
|
|
101
|
+// idleTracker is a shared idle tracker for a pair of relay connections.
|
|
|
102
|
+// Both directions update the same timestamp so that activity in one direction
|
|
|
103
|
+// prevents the other (idle) direction from timing out.
|
|
|
104
|
+type idleTracker struct {
|
|
|
105
|
+ lastActive atomic.Int64 // unix nanos
|
|
|
106
|
+ timeout time.Duration
|
|
|
107
|
+}
|
|
|
108
|
+
|
|
|
109
|
+func newIdleTracker(timeout time.Duration) *idleTracker {
|
|
|
110
|
+ t := &idleTracker{timeout: timeout}
|
|
|
111
|
+ t.touch()
|
|
|
112
|
+
|
|
|
113
|
+ return t
|
|
|
114
|
+}
|
|
|
115
|
+
|
|
|
116
|
+func (t *idleTracker) touch() {
|
|
|
117
|
+ t.lastActive.Store(time.Now().UnixNano())
|
|
|
118
|
+}
|
|
|
119
|
+
|
|
|
120
|
+func (t *idleTracker) isIdle() bool {
|
|
|
121
|
+ last := time.Unix(0, t.lastActive.Load())
|
|
|
122
|
+
|
|
|
123
|
+ return time.Since(last) >= t.timeout
|
|
|
124
|
+}
|
|
|
125
|
+
|
|
100
|
126
|
type connIdleTimeout struct {
|
|
101
|
127
|
essentials.Conn
|
|
102
|
128
|
|
|
103
|
|
- timeout time.Duration
|
|
|
129
|
+ tracker *idleTracker
|
|
104
|
130
|
}
|
|
105
|
131
|
|
|
106
|
132
|
func (c connIdleTimeout) Read(b []byte) (int, error) {
|
|
107
|
|
- c.SetReadDeadline(time.Now().Add(c.timeout)) //nolint: errcheck
|
|
|
133
|
+ for {
|
|
|
134
|
+ c.SetReadDeadline(time.Now().Add(c.tracker.timeout)) //nolint: errcheck
|
|
|
135
|
+
|
|
|
136
|
+ n, err := c.Conn.Read(b)
|
|
|
137
|
+ if n > 0 {
|
|
|
138
|
+ c.tracker.touch()
|
|
108
|
139
|
|
|
109
|
|
- return c.Conn.Read(b) //nolint: wrapcheck
|
|
|
140
|
+ return n, err //nolint: wrapcheck
|
|
|
141
|
+ }
|
|
|
142
|
+
|
|
|
143
|
+ if err != nil {
|
|
|
144
|
+ if netErr, ok := err.(net.Error); ok && netErr.Timeout() && !c.tracker.isIdle() { //nolint: errorlint
|
|
|
145
|
+ continue
|
|
|
146
|
+ }
|
|
|
147
|
+
|
|
|
148
|
+ return 0, err //nolint: wrapcheck
|
|
|
149
|
+ }
|
|
|
150
|
+
|
|
|
151
|
+ return 0, nil
|
|
|
152
|
+ }
|
|
110
|
153
|
}
|
|
111
|
154
|
|
|
112
|
155
|
func (c connIdleTimeout) Write(b []byte) (int, error) {
|
|
113
|
|
- c.SetWriteDeadline(time.Now().Add(c.timeout)) //nolint: errcheck
|
|
|
156
|
+ c.SetWriteDeadline(time.Now().Add(c.tracker.timeout)) //nolint: errcheck
|
|
114
|
157
|
|
|
115
|
|
- return c.Conn.Write(b) //nolint: wrapcheck
|
|
|
158
|
+ n, err := c.Conn.Write(b)
|
|
|
159
|
+ if n > 0 {
|
|
|
160
|
+ c.tracker.touch()
|
|
|
161
|
+ }
|
|
|
162
|
+
|
|
|
163
|
+ return n, err //nolint: wrapcheck
|
|
116
|
164
|
}
|