Bläddra i källkod

Merge remote-tracking branch 'origin/stable' into v2

tags/v2.2.8
9seconds 4 veckor sedan
förälder
incheckning
83a31e0458

+ 13
- 5
example.config.toml Visa fil

@@ -204,14 +204,22 @@ proxies = [
204 204
 # define a global timeout on establishing of network connections. idle
205 205
 # means a timeout on pumping data between sockset when nothing is
206 206
 # happening.
207
-#
208
-# please be noticed that handshakes have no timeouts intentionally. You can
209
-# find a reasoning here:
210
-# https://www.ndss-symposium.org/wp-content/uploads/2020/02/23087-paper.pdf
211 207
 [network.timeout]
212 208
 tcp = "5s"
213 209
 http = "10s"
214
-idle = "1m"
210
+idle = "5m"
211
+handshake = "10s"
212
+
213
+# this defines a configuration for TCP keep alives. Default values are taken
214
+# from Golang default behavior.
215
+[network.keep-alive]
216
+disabled = false
217
+# idle means a time period after which we start sending TCP Keep Alive probes
218
+idle = "15s"
219
+# interval is a period between 2 consecutive probes
220
+interval = "15s"
221
+# if we miss that many probes, a connection will be considered as a dead one.
222
+count = 9
215 223
 
216 224
 # mtg has to mimic real websites. It does not mean domain fronting, it also
217 225
 # means that traffic characteristics should be similar to real world traffic.

+ 2
- 2
go.mod Visa fil

@@ -4,7 +4,7 @@ go 1.26
4 4
 
5 5
 require (
6 6
 	github.com/OneOfOne/xxhash v1.2.8
7
-	github.com/alecthomas/kong v1.14.0
7
+	github.com/alecthomas/kong v1.15.0
8 8
 	github.com/alecthomas/units v0.0.0-20240927000941-0f3dac36c52b
9 9
 	github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5
10 10
 	github.com/babolivier/go-doh-client v0.0.0-20201028162107-a76cff4cb8b6
@@ -28,7 +28,7 @@ require (
28 28
 
29 29
 require (
30 30
 	github.com/beevik/ntp v1.5.0
31
-	github.com/ncruces/go-dns v1.3.2
31
+	github.com/ncruces/go-dns v1.3.3
32 32
 	github.com/pelletier/go-toml/v2 v2.3.0
33 33
 	github.com/pires/go-proxyproto v0.11.0
34 34
 	github.com/things-go/go-socks5 v0.1.0

+ 4
- 4
go.sum Visa fil

@@ -2,8 +2,8 @@ github.com/OneOfOne/xxhash v1.2.8 h1:31czK/TI9sNkxIKfaUfGlU47BAxQ0ztGgd9vPyqimf8
2 2
 github.com/OneOfOne/xxhash v1.2.8/go.mod h1:eZbhyaAYD41SGSSsnmcpxVoRiQ/MPUTjUdIIOT9Um7Q=
3 3
 github.com/alecthomas/assert/v2 v2.11.0 h1:2Q9r3ki8+JYXvGsDyBXwH3LcJ+WK5D0gc5E8vS6K3D0=
4 4
 github.com/alecthomas/assert/v2 v2.11.0/go.mod h1:Bze95FyfUr7x34QZrjL+XP+0qgp/zg8yS+TtBj1WA3k=
5
-github.com/alecthomas/kong v1.14.0 h1:gFgEUZWu2ZmZ+UhyZ1bDhuutbKN1nTtJTwh19Wsn21s=
6
-github.com/alecthomas/kong v1.14.0/go.mod h1:wrlbXem1CWqUV5Vbmss5ISYhsVPkBb1Yo7YKJghju2I=
5
+github.com/alecthomas/kong v1.15.0 h1:BVJstKbpO73zKpmIu+m/aLRrNmWwxXPIGTNin9VmLVI=
6
+github.com/alecthomas/kong v1.15.0/go.mod h1:wrlbXem1CWqUV5Vbmss5ISYhsVPkBb1Yo7YKJghju2I=
7 7
 github.com/alecthomas/repr v0.5.2 h1:SU73FTI9D1P5UNtvseffFSGmdNci/O6RsqzeXJtP0Qs=
8 8
 github.com/alecthomas/repr v0.5.2/go.mod h1:Fr0507jx4eOXV7AlPV6AVZLYrLIuIeSOWtW57eE/O/4=
9 9
 github.com/alecthomas/units v0.0.0-20240927000941-0f3dac36c52b h1:mimo19zliBX/vSQ6PWWSL9lK8qwHozUj03+zLoEB8O0=
@@ -48,8 +48,8 @@ github.com/miekg/dns v1.1.51 h1:0+Xg7vObnhrz/4ZCZcZh7zPXlmU0aveS2HDBd0m0qSo=
48 48
 github.com/miekg/dns v1.1.51/go.mod h1:2Z9d3CP1LQWihRZUf29mQ19yDThaI4DAYzte2CaQW5c=
49 49
 github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA=
50 50
 github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
51
-github.com/ncruces/go-dns v1.3.2 h1:kBLuUZBgkQ4qF4WDXZRQ4rG0Gk6sLVJQ5tESkWrxUa0=
52
-github.com/ncruces/go-dns v1.3.2/go.mod h1:tuzixNY8PY/M7yUzcvRbUaeLs3ifIdydpi5H2bfRU+s=
51
+github.com/ncruces/go-dns v1.3.3 h1:59OV7XoJrTCoUMZjWRVs4GOjtntMTZqiQ5Mn+BT13hk=
52
+github.com/ncruces/go-dns v1.3.3/go.mod h1:tuzixNY8PY/M7yUzcvRbUaeLs3ifIdydpi5H2bfRU+s=
53 53
 github.com/panjf2000/ants/v2 v2.12.0 h1:u9JhESo83i/GkZnhfTNuFMMWcNt7mnV1bGJ6FT4wXH8=
54 54
 github.com/panjf2000/ants/v2 v2.12.0/go.mod h1:tSQuaNQ6r6NRhPt+IZVUevvDyFMTs+eS4ztZc52uJTY=
55 55
 github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc=

+ 6
- 0
internal/cli/doctor.go Visa fil

@@ -97,6 +97,12 @@ func (d *Doctor) Run(cli *CLI, version string) error {
97 97
 		conf.Network.Timeout.TCP.Get(10*time.Second),
98 98
 		conf.Network.Timeout.HTTP.Get(0),
99 99
 		conf.Network.Timeout.Idle.Get(0),
100
+		net.KeepAliveConfig{
101
+			Enable:   !conf.Network.KeepAlive.Disabled.Get(false),
102
+			Idle:     conf.Network.KeepAlive.Idle.Get(0),
103
+			Interval: conf.Network.KeepAlive.Interval.Get(0),
104
+			Count:    int(conf.Network.KeepAlive.Count.Get(0)),
105
+		},
100 106
 	)
101 107
 
102 108
 	fmt.Println("Validate native network connectivity")

+ 8
- 2
internal/cli/run_proxy.go Visa fil

@@ -5,7 +5,6 @@ import (
5 5
 	"fmt"
6 6
 	"net"
7 7
 	"os"
8
-	"time"
9 8
 
10 9
 	"github.com/9seconds/mtg/v2/antireplay"
11 10
 	"github.com/9seconds/mtg/v2/events"
@@ -51,6 +50,12 @@ func makeNetwork(conf *config.Config, version string) (mtglib.Network, error) {
51 50
 		conf.Network.Timeout.TCP.Get(0),
52 51
 		conf.Network.Timeout.HTTP.Get(0),
53 52
 		conf.Network.Timeout.Idle.Get(0),
53
+		net.KeepAliveConfig{
54
+			Enable:   !conf.Network.KeepAlive.Disabled.Get(false),
55
+			Idle:     conf.Network.KeepAlive.Idle.Get(0),
56
+			Interval: conf.Network.KeepAlive.Interval.Get(0),
57
+			Count:    int(conf.Network.KeepAlive.Count.Get(0)),
58
+		},
54 59
 	)
55 60
 
56 61
 	proxyDialers := make([]mtglib.Network, len(conf.Network.Proxies))
@@ -263,7 +268,8 @@ func runProxy(conf *config.Config, version string) error { //nolint: funlen
263 268
 
264 269
 		AllowFallbackOnUnknownDC: conf.AllowFallbackOnUnknownDC.Get(false),
265 270
 		TolerateTimeSkewness:     conf.TolerateTimeSkewness.Value,
266
-		IdleTimeout:              conf.Network.Timeout.Idle.Get(time.Minute),
271
+		IdleTimeout:              conf.Network.Timeout.Idle.Get(mtglib.DefaultIdleTimeout),
272
+		HandshakeTimeout:         conf.Network.Timeout.Handshake.Get(mtglib.DefaultHandshakeTimeout),
267 273
 
268 274
 		DoppelGangerURLs:    doppelGangerURLs,
269 275
 		DoppelGangerPerRaid: conf.Defense.Doppelganger.Repeats.Get(mtglib.DoppelGangerPerRaid),

+ 10
- 3
internal/config/config.go Visa fil

@@ -60,10 +60,17 @@ type Config struct {
60 60
 	} `json:"defense"`
61 61
 	Network struct {
62 62
 		Timeout struct {
63
-			TCP  TypeDuration `json:"tcp"`
64
-			HTTP TypeDuration `json:"http"`
65
-			Idle TypeDuration `json:"idle"`
63
+			TCP       TypeDuration `json:"tcp"`
64
+			HTTP      TypeDuration `json:"http"`
65
+			Idle      TypeDuration `json:"idle"`
66
+			Handshake TypeDuration `json:"handshake"`
66 67
 		} `json:"timeout"`
68
+		KeepAlive struct {
69
+			Disabled TypeBool        `json:"disabled"`
70
+			Idle     TypeDuration    `json:"idle"`
71
+			Interval TypeDuration    `json:"interval"`
72
+			Count    TypeConcurrency `json:"count"`
73
+		} `json:"keepAlive"`
67 74
 		DOHIP   TypeIP         `json:"dohIp"`
68 75
 		DNS     TypeDNSURI     `json:"dns"`
69 76
 		Proxies []TypeProxyURL `json:"proxies"`

+ 10
- 3
internal/config/parse.go Visa fil

@@ -55,10 +55,17 @@ type tomlConfig struct {
55 55
 	} `toml:"defense" json:"defense,omitempty"`
56 56
 	Network struct {
57 57
 		Timeout struct {
58
-			TCP  string `toml:"tcp" json:"tcp,omitempty"`
59
-			HTTP string `toml:"http" json:"http,omitempty"`
60
-			Idle string `toml:"idle" json:"idle,omitempty"`
58
+			TCP       string `toml:"tcp" json:"tcp,omitempty"`
59
+			HTTP      string `toml:"http" json:"http,omitempty"`
60
+			Idle      string `toml:"idle" json:"idle,omitempty"`
61
+			Handshake string `toml:"handshake" json:"handshake,omitempty"`
61 62
 		} `toml:"timeout" json:"timeout,omitempty"`
63
+		KeepAlive struct {
64
+			Disabled bool   `toml:"disabled" json:"disabled,omitempty"`
65
+			Idle     string `toml:"idle" json:"idle,omitempty"`
66
+			Interval string `toml:"interval" json:"interval,omitempty"`
67
+			Count    uint   `toml:"count" json:"count,omitempty"`
68
+		} `toml:"keep-alive" json:"keepAlive,omitempty"`
62 69
 		DOHIP   string   `toml:"doh-ip" json:"dohIp,omitempty"`
63 70
 		DNS     string   `toml:"dns" json:"dns,omitempty"`
64 71
 		Proxies []string `toml:"proxies" json:"proxies,omitempty"`

+ 7
- 2
mtglib/init.go Visa fil

@@ -77,8 +77,13 @@ const (
77 77
 	// DefaultIdleTimeout is a default timeout for closing a connection in case of
78 78
 	// idling.
79 79
 	//
80
-	// Deprecated: no longer in use because of changed TCP relay algorithm.
81
-	DefaultIdleTimeout = time.Minute
80
+	// Set to 5 minutes to survive typical mobile sleep periods (2-5 min) and
81
+	// avoid racing with MTProto ping_delay_disconnect (~60s interval).
82
+	DefaultIdleTimeout = 5 * time.Minute
83
+
84
+	// DefaultHandshakeTimeout defines a time period during which the
85
+	// all handshake ceremonies must be completed.
86
+	DefaultHandshakeTimeout = 10 * time.Second
82 87
 
83 88
 	// DefaultTolerateTimeSkewness is a default timeout for time skewness on a
84 89
 	// faketls timeout verification.

+ 27
- 13
mtglib/internal/tls/fake/client_side.go Visa fil

@@ -6,6 +6,7 @@ import (
6 6
 	"crypto/sha256"
7 7
 	"crypto/subtle"
8 8
 	"encoding/binary"
9
+	"errors"
9 10
 	"fmt"
10 11
 	"io"
11 12
 	"net"
@@ -20,12 +21,19 @@ const (
20 21
 	// record_type(1) + version(2) + size(2) + handshake_type(1) + uint24_length(3) + client_version(2)
21 22
 	RandomOffset = 1 + 2 + 2 + 1 + 3 + 2
22 23
 
24
+	// https://datatracker.ietf.org/doc/html/rfc8701#name-grease-values
25
+	// https://medium.com/asecuritysite-when-bob-met-alice/in-cybersecurity-what-is-grease-9f8850558dea
26
+	GreaseMask      = 0x0f0f
27
+	GreaseValueType = 0x0a0a
28
+
23 29
 	sniDNSNamesListType = 0
24 30
 )
25 31
 
26 32
 var (
27 33
 	emptyRandom = [RandomLen]byte{}
28 34
 	extTypeSNI  = [2]byte{}
35
+
36
+	ErrCannotFindCipher = errors.New("cannot find a cipher")
29 37
 )
30 38
 
31 39
 type ClientHello struct {
@@ -40,11 +48,6 @@ func ReadClientHello(
40 48
 	hostname string,
41 49
 	tolerateTimeSkewness time.Duration,
42 50
 ) (*ClientHello, error) {
43
-	if err := conn.SetReadDeadline(time.Now().Add(ClientHelloReadTimeout)); err != nil {
44
-		return nil, fmt.Errorf("cannot set read deadline: %w", err)
45
-	}
46
-	defer conn.SetReadDeadline(resetDeadline) //nolint: errcheck
47
-
48 51
 	// This is how FakeTLS is organized:
49 52
 	//  1. We create sha256 HMAC with a given secret
50 53
 	//  2. We dump there a whole TLS frame except of the fact that random
@@ -130,16 +133,27 @@ func parseHandshake(r io.Reader) (*ClientHello, error) {
130 133
 
131 134
 	cipherSuiteLen := int64(binary.BigEndian.Uint16(header[:]))
132 135
 
133
-	// we do not care about picking up any cipher. we pick the first one,
134
-	// so it is always should be present.
135
-	if _, err := io.ReadFull(r, header[:]); err != nil {
136
-		return nil, fmt.Errorf("cannot read first cipher suite: %w", err)
137
-	}
136
+	// Pick the first non-GREASE cipher suite from the list.
137
+	// Real TLS servers never select GREASE values (RFC 8701, pattern 0x?a?a),
138
+	// so echoing them back is a trivial DPI fingerprint.
139
+	// cipherSuiteLen is in bytes; each cipher suite is 2 bytes.
140
+	for range cipherSuiteLen / 2 {
141
+		if _, err := io.ReadFull(r, header[:]); err != nil {
142
+			return nil, fmt.Errorf("cannot read cipher suite: %w", err)
143
+		}
144
+
145
+		if hello.CipherSuite != 0 {
146
+			// do not forget we have to scan until the end
147
+			continue
148
+		}
138 149
 
139
-	hello.CipherSuite = binary.BigEndian.Uint16(header[:])
150
+		if cs := binary.BigEndian.Uint16(header[:]); cs&GreaseMask != GreaseValueType {
151
+			hello.CipherSuite = cs
152
+		}
153
+	}
140 154
 
141
-	if _, err := io.CopyN(io.Discard, r, cipherSuiteLen-2); err != nil {
142
-		return nil, fmt.Errorf("cannot skip remaining cipher suites: %w", err)
155
+	if hello.CipherSuite == 0 {
156
+		return nil, ErrCannotFindCipher
143 157
 	}
144 158
 
145 159
 	if _, err := io.ReadFull(r, header[:1]); err != nil {

+ 0
- 6
mtglib/internal/tls/fake/client_side_snapshot_test.go Visa fil

@@ -12,7 +12,6 @@ import (
12 12
 	"github.com/9seconds/mtg/v2/mtglib"
13 13
 	"github.com/9seconds/mtg/v2/mtglib/internal/tls/fake"
14 14
 	"github.com/stretchr/testify/assert"
15
-	"github.com/stretchr/testify/mock"
16 15
 	"github.com/stretchr/testify/require"
17 16
 	"github.com/stretchr/testify/suite"
18 17
 )
@@ -71,11 +70,6 @@ func (suite *ParseClientHelloSnapshotTestSuite) makeConn(data []byte) *parseClie
71 70
 		readBuf: readBuf,
72 71
 	}
73 72
 
74
-	connMock.
75
-		On("SetReadDeadline", mock.AnythingOfType("time.Time")).
76
-		Twice().
77
-		Return(nil)
78
-
79 73
 	return connMock
80 74
 }
81 75
 

+ 23
- 28
mtglib/internal/tls/fake/client_side_test.go Visa fil

@@ -2,9 +2,9 @@ package fake_test
2 2
 
3 3
 import (
4 4
 	"bytes"
5
+	cryptotls "crypto/tls"
5 6
 	"encoding/binary"
6 7
 	"encoding/json"
7
-	"errors"
8 8
 	"io"
9 9
 	"os"
10 10
 	"testing"
@@ -14,7 +14,6 @@ import (
14 14
 	"github.com/9seconds/mtg/v2/mtglib"
15 15
 	"github.com/9seconds/mtg/v2/mtglib/internal/tls"
16 16
 	"github.com/9seconds/mtg/v2/mtglib/internal/tls/fake"
17
-	"github.com/stretchr/testify/mock"
18 17
 	"github.com/stretchr/testify/require"
19 18
 	"github.com/stretchr/testify/suite"
20 19
 )
@@ -53,11 +52,6 @@ func (suite *ParseClientHelloTestSuite) SetupTest() {
53 52
 	suite.connMock = &parseClientHelloConnMock{
54 53
 		readBuf: suite.readBuf,
55 54
 	}
56
-
57
-	suite.connMock.
58
-		On("SetReadDeadline", mock.AnythingOfType("time.Time")).
59
-		Twice().
60
-		Return(nil)
61 55
 }
62 56
 
63 57
 func (suite *ParseClientHelloTestSuite) TearDownTest() {
@@ -69,23 +63,11 @@ type ParseClientHello_TLSHeaderTestSuite struct {
69 63
 }
70 64
 
71 65
 func (suite *ParseClientHello_TLSHeaderTestSuite) TestEmpty() {
72
-	suite.connMock.ExpectedCalls = []*mock.Call{}
73
-	suite.connMock.
74
-		On("SetReadDeadline", mock.AnythingOfType("time.Time")).
75
-		Once().
76
-		Return(errors.New("fail"))
77
-
78 66
 	_, err := fake.ReadClientHello(suite.connMock, suite.secret.Key[:], suite.secret.Host, TolerateTime)
79
-	suite.ErrorContains(err, "fail")
67
+	suite.ErrorContains(err, "cannot read client hello")
80 68
 }
81 69
 
82 70
 func (suite *ParseClientHello_TLSHeaderTestSuite) TestNothing() {
83
-	suite.connMock.ExpectedCalls = []*mock.Call{}
84
-	suite.connMock.
85
-		On("SetReadDeadline", mock.AnythingOfType("time.Time")).
86
-		Twice().
87
-		Return(nil)
88
-
89 71
 	_, err := fake.ReadClientHello(suite.connMock, suite.secret.Key[:], suite.secret.Host, TolerateTime)
90 72
 	suite.ErrorIs(err, io.EOF)
91 73
 }
@@ -234,12 +216,13 @@ func (suite *ParseClientHelloHandshakeBodyTestSuite) TestCannotReadCipherSuiteLe
234 216
 }
235 217
 
236 218
 func (suite *ParseClientHelloHandshakeBodyTestSuite) TestCannotReadFirstCipherSuite() {
237
-	body := make([]byte, 2+fake.RandomLen+1+2)
219
+	body := make([]byte, 2+fake.RandomLen+1+2+1) // cipherSuiteLen=2 but only 1 byte available
220
+	binary.BigEndian.PutUint16(body[2+fake.RandomLen+1:], 2)
238 221
 
239 222
 	suite.writeBody(body)
240 223
 
241 224
 	_, err := fake.ReadClientHello(suite.connMock, suite.secret.Key[:], suite.secret.Host, TolerateTime)
242
-	suite.ErrorContains(err, "cannot read first cipher suite")
225
+	suite.ErrorContains(err, "cannot read cipher suite")
243 226
 }
244 227
 
245 228
 func (suite *ParseClientHelloHandshakeBodyTestSuite) TestCannotSkipRemainingCipherSuites() {
@@ -249,12 +232,27 @@ func (suite *ParseClientHelloHandshakeBodyTestSuite) TestCannotSkipRemainingCiph
249 232
 	suite.writeBody(body)
250 233
 
251 234
 	_, err := fake.ReadClientHello(suite.connMock, suite.secret.Key[:], suite.secret.Host, TolerateTime)
252
-	suite.ErrorContains(err, "cannot skip remaining cipher suites")
235
+	suite.ErrorContains(err, "cannot read cipher suite")
236
+}
237
+
238
+func (suite *ParseClientHelloHandshakeBodyTestSuite) TestCannotFindCipher() {
239
+	// All cipher suites are GREASE values — must return ErrCannotFindCipher.
240
+	body := make([]byte, 2+fake.RandomLen+1+2+4+1)
241
+	binary.BigEndian.PutUint16(body[2+fake.RandomLen+1:], 4)
242
+	binary.BigEndian.PutUint16(body[2+fake.RandomLen+1+2:], 0x0a0a)
243
+	binary.BigEndian.PutUint16(body[2+fake.RandomLen+1+2+2:], 0x1a1a)
244
+	body[2+fake.RandomLen+1+2+4] = 1
245
+
246
+	suite.writeBody(body)
247
+
248
+	_, err := fake.ReadClientHello(suite.connMock, suite.secret.Key[:], suite.secret.Host, TolerateTime)
249
+	suite.ErrorIs(err, fake.ErrCannotFindCipher)
253 250
 }
254 251
 
255 252
 func (suite *ParseClientHelloHandshakeBodyTestSuite) TestCannotReadCompressionMethodsLength() {
256 253
 	body := make([]byte, 2+fake.RandomLen+1+2+2)
257 254
 	binary.BigEndian.PutUint16(body[2+fake.RandomLen+1:], 2)
255
+	binary.BigEndian.PutUint16(body[2+fake.RandomLen+1+2:], cryptotls.TLS_AES_128_GCM_SHA256)
258 256
 
259 257
 	suite.writeBody(body)
260 258
 
@@ -265,6 +263,7 @@ func (suite *ParseClientHelloHandshakeBodyTestSuite) TestCannotReadCompressionMe
265 263
 func (suite *ParseClientHelloHandshakeBodyTestSuite) TestCannotSkipCompressionMethods() {
266 264
 	body := make([]byte, 2+fake.RandomLen+1+2+2+1)
267 265
 	binary.BigEndian.PutUint16(body[2+fake.RandomLen+1:], 2)
266
+	binary.BigEndian.PutUint16(body[2+fake.RandomLen+1+2:], cryptotls.TLS_AES_128_GCM_SHA256)
268 267
 	body[2+fake.RandomLen+1+2+2] = 1
269 268
 
270 269
 	suite.writeBody(body)
@@ -300,6 +299,7 @@ func (suite *ParseClientHelloSNITestSuite) writeExtensions(extensions []byte) {
300 299
 	// cipherSuite(2) + compressionLen(1) + compression(1) = 41
301 300
 	body := make([]byte, 41)
302 301
 	binary.BigEndian.PutUint16(body[35:], 2)
302
+	binary.BigEndian.PutUint16(body[37:], cryptotls.TLS_AES_128_GCM_SHA256)
303 303
 	body[39] = 1
304 304
 
305 305
 	suite.readBuf.Write(body)
@@ -478,11 +478,6 @@ func (s *ParseClientHelloFragmentedTestSuite) makeConn(data []byte) *parseClient
478 478
 		readBuf: readBuf,
479 479
 	}
480 480
 
481
-	connMock.
482
-		On("SetReadDeadline", mock.AnythingOfType("time.Time")).
483
-		Twice().
484
-		Return(nil)
485
-
486 481
 	return connMock
487 482
 }
488 483
 

+ 1
- 10
mtglib/internal/tls/fake/init.go Visa fil

@@ -2,15 +2,6 @@ package fake
2 2
 
3 3
 import (
4 4
 	"errors"
5
-	"time"
6 5
 )
7 6
 
8
-const (
9
-	ClientHelloReadTimeout = 5 * time.Second
10
-)
11
-
12
-var (
13
-	resetDeadline time.Time
14
-
15
-	ErrBadDigest = errors.New("incorrect client random")
16
-)
7
+var ErrBadDigest = errors.New("incorrect client random")

+ 1
- 1
mtglib/internal/tls/fake/server_side_test.go Visa fil

@@ -58,7 +58,7 @@ func (suite *SendServerHelloTestSuite) TestRecordStructure() {
58 58
 	recordType, length, err := tls.ReadRecord(suite.buf, &rec)
59 59
 	suite.NoError(err)
60 60
 	suite.Equal(byte(tls.TypeApplicationData), recordType)
61
-	suite.Greater(length, int64(2500))
61
+	suite.GreaterOrEqual(length, int64(2500))
62 62
 
63 63
 	suite.Empty(suite.buf.Bytes())
64 64
 }

+ 8
- 0
mtglib/internal/tls/fake/testdata/client-hello-ok-grease-first.json Visa fil

@@ -0,0 +1,8 @@
1
+{
2
+  "time": 1617181365,
3
+  "random": "w4TaDfYg/aUKdx1oi68vxMKvHJczRNvtRRppLETzeNE=",
4
+  "sessionId": "St2BZ2uHMFn3B2trD1jfdtpjoJOOg6JBeLhFcyCMCq4=",
5
+  "host": "storage.googleapis.com",
6
+  "cipherSuite": 4867,
7
+  "full": "FgMBAgIBAAH+AwPDhNoN9iD9pQp3HWiLry/Ewq8clzNE2+1FGmksRPN40SBK3YFna4cwWfcHa2sPWN922mOgk46DokF4uEVzIIwKrgA2WloTAxMBEwLALMArwCTAI8AKwAnMqcAwwC/AKMAnwBTAE8yoAJ0AnAA9ADwANQAvwAjAEgAKAQABf/8BAAEAAAAAGwAZAAAWc3RvcmFnZS5nb29nbGVhcGlzLmNvbQAXAAAADQAYABYEAwgEBAEFAwIDCAUIBQUBCAYGAQIBAAUABQEAAAAAM3QAAAASAAAAEAAwAC4CaDIFaDItMTYFaDItMTUFaDItMTQIc3BkeS8zLjEGc3BkeS8zCGh0dHAvMS4xAAsAAgEAADMAJgAkAB0AIAf+6C8fSRJSAC7CyUvdR9kDclNR9KLCsCFHpVZ3bC8iAC0AAgEBACsACQgDBAMDAwIDAQAKAAoACAAdABcAGAAZABUAoQAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"
8
+}

+ 12
- 0
mtglib/proxy.go Visa fil

@@ -28,6 +28,7 @@ type Proxy struct {
28 28
 	allowFallbackOnUnknownDC    bool
29 29
 	tolerateTimeSkewness        time.Duration
30 30
 	idleTimeout                 time.Duration
31
+	handshakeTimeout            time.Duration
31 32
 	domainFrontingPort          int
32 33
 	domainFrontingIP            string
33 34
 	domainFrontingProxyProtocol bool
@@ -66,6 +67,11 @@ func (p *Proxy) ServeConn(conn essentials.Conn) {
66 67
 	ctx := newStreamContext(p.ctx, p.logger, conn)
67 68
 	defer ctx.Close()
68 69
 
70
+	if err := ctx.clientConn.SetDeadline(time.Now().Add(p.handshakeTimeout)); err != nil {
71
+		ctx.logger.WarningError("cannot set handshake timeout", err)
72
+		return
73
+	}
74
+
69 75
 	stop := context.AfterFunc(ctx, func() {
70 76
 		ctx.Close()
71 77
 	})
@@ -97,6 +103,11 @@ func (p *Proxy) ServeConn(conn essentials.Conn) {
97 103
 		return
98 104
 	}
99 105
 
106
+	if err := ctx.clientConn.SetDeadline(time.Time{}); err != nil {
107
+		ctx.logger.WarningError("cannot set deadline", err)
108
+		return
109
+	}
110
+
100 111
 	if err := p.doTelegramCall(ctx); err != nil {
101 112
 		ctx.logger.WarningError("cannot dial to telegram", err)
102 113
 		return
@@ -346,6 +357,7 @@ func NewProxy(opts ProxyOpts) (*Proxy, error) {
346 357
 		domainFrontingIP:         opts.DomainFrontingIP,
347 358
 		tolerateTimeSkewness:     opts.getTolerateTimeSkewness(),
348 359
 		idleTimeout:              opts.getIdleTimeout(),
360
+		handshakeTimeout:         opts.getHandshakeTimeout(),
349 361
 		allowFallbackOnUnknownDC: opts.AllowFallbackOnUnknownDC,
350 362
 		telegram:                 tg,
351 363
 		doppelGanger: doppel.NewGanger(

+ 15
- 1
mtglib/proxy_opts.go Visa fil

@@ -70,6 +70,12 @@ type ProxyOpts struct {
70 70
 	// This is an optional setting.
71 71
 	IdleTimeout time.Duration
72 72
 
73
+	// HandshakeTimeout is a timeout during which all handshake ceremonies must
74
+	// be completed, otherwise this process will be aborted
75
+	//
76
+	// This is an optional setting.
77
+	HandshakeTimeout time.Duration
78
+
73 79
 	// TolerateTimeSkewness is a time boundary that defines a time range where
74 80
 	// faketls timestamp is acceptable.
75 81
 	//
@@ -215,9 +221,17 @@ func (p ProxyOpts) getPreferIP() string {
215 221
 	return p.PreferIP
216 222
 }
217 223
 
224
+func (p ProxyOpts) getHandshakeTimeout() time.Duration {
225
+	if p.HandshakeTimeout == 0 {
226
+		return DefaultHandshakeTimeout
227
+	}
228
+
229
+	return p.HandshakeTimeout
230
+}
231
+
218 232
 func (p ProxyOpts) getIdleTimeout() time.Duration {
219 233
 	if p.IdleTimeout == 0 {
220
-		return time.Minute
234
+		return DefaultIdleTimeout
221 235
 	}
222 236
 
223 237
 	return p.IdleTimeout

+ 14
- 0
network/init.go Visa fil

@@ -36,8 +36,22 @@ const (
36 36
 
37 37
 	// DefaultTCPKeepAlivePeriod defines a time period between 2 consequitive
38 38
 	// probes.
39
+	//
40
+	// Deprecated: use DefaultKeepAliveIdle and DefaultKeepAliveInterval instead.
39 41
 	DefaultTCPKeepAlivePeriod = 10 * time.Second
40 42
 
43
+	// DefaultKeepAliveIdle is the time a connection must be idle before
44
+	// the first keepalive probe is sent.
45
+	DefaultKeepAliveIdle = 30 * time.Second
46
+
47
+	// DefaultKeepAliveInterval is the time between consecutive keepalive
48
+	// probes.
49
+	DefaultKeepAliveInterval = 10 * time.Second
50
+
51
+	// DefaultKeepAliveCount is the number of unacknowledged probes before
52
+	// the connection is considered dead.
53
+	DefaultKeepAliveCount = 3
54
+
41 55
 	// ProxyDialerOpenThreshold is used for load balancing SOCKS5 dialer only.
42 56
 	//
43 57
 	// This dialer uses circuit breaker with of 3 stages: OPEN, HALF_OPEN and

+ 7
- 2
network/sockopts.go Visa fil

@@ -20,8 +20,13 @@ func SetServerSocketOptions(conn net.Conn, bufferSize int) error {
20 20
 }
21 21
 
22 22
 func setCommonSocketOptions(conn *net.TCPConn) error {
23
-	if err := conn.SetKeepAlivePeriod(DefaultTCPKeepAlivePeriod); err != nil {
24
-		return fmt.Errorf("cannot set time period of TCP keepalive probes: %w", err)
23
+	if err := conn.SetKeepAliveConfig(net.KeepAliveConfig{
24
+		Enable:   true,
25
+		Idle:     DefaultKeepAliveIdle,
26
+		Interval: DefaultKeepAliveInterval,
27
+		Count:    DefaultKeepAliveCount,
28
+	}); err != nil {
29
+		return fmt.Errorf("cannot configure TCP keepalive: %w", err)
25 30
 	}
26 31
 
27 32
 	if err := conn.SetLinger(tcpLingerTimeout); err != nil {

+ 93
- 0
network/sockopts_test.go Visa fil

@@ -0,0 +1,93 @@
1
+//go:build linux || darwin
2
+// +build linux darwin
3
+
4
+package network_test
5
+
6
+import (
7
+	"net"
8
+	"runtime"
9
+	"syscall"
10
+	"testing"
11
+	"time"
12
+
13
+	"github.com/9seconds/mtg/v2/network"
14
+	"github.com/stretchr/testify/require"
15
+	"golang.org/x/sys/unix"
16
+)
17
+
18
+func tcpKeepIdleOption() int {
19
+	if runtime.GOOS == "darwin" {
20
+		return 0x10 // TCP_KEEPALIVE on macOS
21
+	}
22
+
23
+	return 0x4 // TCP_KEEPIDLE on Linux
24
+}
25
+
26
+func TestSetClientSocketOptionsKeepAlive(t *testing.T) {
27
+	t.Parallel()
28
+
29
+	listener, err := net.Listen("tcp", "127.0.0.1:0")
30
+	require.NoError(t, err)
31
+	defer func() {
32
+		err := listener.Close()
33
+		require.NoError(t, err)
34
+	}()
35
+
36
+	type dialResult struct {
37
+		conn net.Conn
38
+		err  error
39
+	}
40
+
41
+	dialDone := make(chan dialResult, 1)
42
+
43
+	go func() {
44
+		c, err := net.Dial("tcp", listener.Addr().String())
45
+		dialDone <- dialResult{conn: c, err: err}
46
+	}()
47
+
48
+	tcpListener, ok := listener.(*net.TCPListener)
49
+	require.True(t, ok, "listener must be a *net.TCPListener")
50
+
51
+	require.NoError(t, tcpListener.SetDeadline(time.Now().Add(5*time.Second)))
52
+
53
+	accepted, err := listener.Accept()
54
+	require.NoError(t, err)
55
+	defer func() {
56
+		err := accepted.Close()
57
+		require.NoError(t, err)
58
+	}()
59
+
60
+	dr := <-dialDone
61
+	require.NoError(t, dr.err)
62
+	defer func() {
63
+		err := dr.conn.Close()
64
+		require.NoError(t, err)
65
+	}()
66
+
67
+	err = network.SetClientSocketOptions(accepted, 0)
68
+	require.NoError(t, err)
69
+
70
+	tcpConn := accepted.(*net.TCPConn)
71
+
72
+	rawConn, err := tcpConn.SyscallConn()
73
+	require.NoError(t, err)
74
+
75
+	err = rawConn.Control(func(fd uintptr) {
76
+		val, err := unix.GetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_KEEPALIVE)
77
+		require.NoError(t, err)
78
+		require.NotEqual(t, 0, val, "SO_KEEPALIVE should be enabled")
79
+
80
+		idle, err := unix.GetsockoptInt(int(fd), syscall.IPPROTO_TCP, tcpKeepIdleOption())
81
+		require.NoError(t, err)
82
+		require.Equal(t, int(network.DefaultKeepAliveIdle.Seconds()), idle)
83
+
84
+		interval, err := unix.GetsockoptInt(int(fd), syscall.IPPROTO_TCP, unix.TCP_KEEPINTVL)
85
+		require.NoError(t, err)
86
+		require.Equal(t, int(network.DefaultKeepAliveInterval.Seconds()), interval)
87
+
88
+		count, err := unix.GetsockoptInt(int(fd), syscall.IPPROTO_TCP, unix.TCP_KEEPCNT)
89
+		require.NoError(t, err)
90
+		require.Equal(t, network.DefaultKeepAliveCount, count)
91
+	})
92
+	require.NoError(t, err)
93
+}

+ 1
- 1
network/v2/base_http_test.go Visa fil

@@ -25,7 +25,7 @@ func (suite *BaseHTTPTestSuite) SetupSuite() {
25 25
 }
26 26
 
27 27
 func (suite *BaseHTTPTestSuite) SetupTest() {
28
-	suite.client = network.New(nil, "mtg/1", 0, 0, 0).MakeHTTPClient(nil)
28
+	suite.client = network.New(nil, "mtg/1", 0, 0, 0, network.DefaultKeepAliveConfig).MakeHTTPClient(nil)
29 29
 }
30 30
 
31 31
 func (suite *BaseHTTPTestSuite) TestGet() {

+ 1
- 1
network/v2/base_network_test.go Visa fil

@@ -19,7 +19,7 @@ type BaseNetworkTestSuite struct {
19 19
 func (suite *BaseNetworkTestSuite) SetupSuite() {
20 20
 	suite.EchoServerTestSuite.SetupSuite()
21 21
 
22
-	suite.net = network.New(nil, "agent", 0, 0, 0)
22
+	suite.net = network.New(nil, "agent", 0, 0, 0, network.DefaultKeepAliveConfig)
23 23
 }
24 24
 
25 25
 func (suite *BaseNetworkTestSuite) TestDialUnknownNetwork() {

+ 43
- 1
network/v2/init.go Visa fil

@@ -11,6 +11,7 @@ package network
11 11
 
12 12
 import (
13 13
 	"errors"
14
+	"net"
14 15
 	"time"
15 16
 )
16 17
 
@@ -26,14 +27,55 @@ const (
26 27
 
27 28
 	// DefaultTCPKeepAlivePeriod defines a time period between 2 consecuitive
28 29
 	// probes.
30
+	//
31
+	// Deprecated: use DefaultKeepAliveConfig
29 32
 	DefaultTCPKeepAlivePeriod = 10 * time.Second
30 33
 
34
+	// DefaultKeepAliveIdle is the time a connection must be idle before
35
+	// the first keepalive probe is sent.
36
+	//
37
+	// Deprecated: use DefaultKeepAliveConfig
38
+	DefaultKeepAliveIdle = 30 * time.Second
39
+
40
+	// DefaultKeepAliveInterval is the time between consecutive keepalive
41
+	// probes.
42
+	//
43
+	// Deprecated: use DefaultKeepAliveConfig
44
+	DefaultKeepAliveInterval = 10 * time.Second
45
+
46
+	// DefaultKeepAliveCount is the number of unacknowledged probes before
47
+	// the connection is considered dead.
48
+	//
49
+	// Deprecated: use DefaultKeepAliveConfig
50
+	DefaultKeepAliveCount = 3
51
+
31 52
 	// User Agent to use in HTTP client.
32 53
 	UserAgent = "curl/8.5.0"
33 54
 
34 55
 	// tcpLingerTimeout defines a number of seconds to wait for sending
35 56
 	// unacknowledged data.
36 57
 	tcpLingerTimeout = 1
58
+
59
+	// tcpNotSentLowat limits the amount of unsent data queued in the
60
+	// kernel write buffer per socket. When the unsent data drops below
61
+	// this threshold, the socket becomes writable again. This reduces
62
+	// per-connection memory usage and bufferbloat by applying
63
+	// back-pressure to the relay loop instead of piling up data in
64
+	// kernel buffers.
65
+	tcpNotSentLowat = 128 * 1024
37 66
 )
38 67
 
39
-var ErrCannotDial = errors.New("cannot dial to any address")
68
+var (
69
+	ErrCannotDial = errors.New("cannot dial to any address")
70
+
71
+	// DefaultKeepAliveConfig defines a default configuration for
72
+	// keep alive settings. As per official documentation, if keep alive
73
+	// is enabled, then:
74
+	//
75
+	//  Idle = 15 * time.Second
76
+	//  Interval = 15 * time.Second
77
+	//  Count = 9
78
+	DefaultKeepAliveConfig = net.KeepAliveConfig{
79
+		Enable: true,
80
+	}
81
+)

+ 10
- 7
network/v2/network.go Visa fil

@@ -14,9 +14,10 @@ import (
14 14
 type network struct {
15 15
 	net.Dialer
16 16
 
17
-	httpTimeout time.Duration
18
-	idleTimeout time.Duration
19
-	userAgent   string
17
+	keepAliveConfig net.KeepAliveConfig
18
+	httpTimeout     time.Duration
19
+	idleTimeout     time.Duration
20
+	userAgent       string
20 21
 }
21 22
 
22 23
 func (n *network) Dial(network, address string) (essentials.Conn, error) {
@@ -37,7 +38,7 @@ func (n *network) DialContext(ctx context.Context, network, address string) (ess
37 38
 
38 39
 	tcpConn := conn.(*net.TCPConn)
39 40
 
40
-	return tcpConn, setCommonSocketOptions(tcpConn)
41
+	return tcpConn, setCommonSocketOptions(tcpConn, n.keepAliveConfig)
41 42
 }
42 43
 
43 44
 func (n *network) MakeHTTPClient(
@@ -71,6 +72,7 @@ func New(
71 72
 	tcpTimeout,
72 73
 	httpTimeout,
73 74
 	idleTimeout time.Duration,
75
+	keepAliveConfig net.KeepAliveConfig,
74 76
 ) mtglib.Network {
75 77
 	if dnsResolver == nil {
76 78
 		dnsResolver = net.DefaultResolver
@@ -86,8 +88,9 @@ func New(
86 88
 			Resolver:      dnsResolver,
87 89
 			FallbackDelay: -1,
88 90
 		},
89
-		userAgent:   userAgent,
90
-		idleTimeout: idleTimeout,
91
-		httpTimeout: httpTimeout,
91
+		userAgent:       userAgent,
92
+		idleTimeout:     idleTimeout,
93
+		httpTimeout:     httpTimeout,
94
+		keepAliveConfig: keepAliveConfig,
92 95
 	}
93 96
 }

+ 7
- 3
network/v2/sockopts.go Visa fil

@@ -5,9 +5,9 @@ import (
5 5
 	"net"
6 6
 )
7 7
 
8
-func setCommonSocketOptions(conn *net.TCPConn) error {
9
-	if err := conn.SetKeepAlivePeriod(DefaultTCPKeepAlivePeriod); err != nil {
10
-		return fmt.Errorf("cannot set time period of TCP keepalive probes: %w", err)
8
+func setCommonSocketOptions(conn *net.TCPConn, keepAliveConfig net.KeepAliveConfig) error {
9
+	if err := conn.SetKeepAliveConfig(keepAliveConfig); err != nil {
10
+		return fmt.Errorf("cannot configure TCP keepalive: %w", err)
11 11
 	}
12 12
 
13 13
 	if err := conn.SetLinger(tcpLingerTimeout); err != nil {
@@ -23,5 +23,9 @@ func setCommonSocketOptions(conn *net.TCPConn) error {
23 23
 		return fmt.Errorf("cannot setup SO_REUSEADDR/PORT: %w", err)
24 24
 	}
25 25
 
26
+	setCongestionControl(rawConn)
27
+	setTCPUserTimeout(rawConn, keepAliveConfig)
28
+	setNotSentLowat(rawConn)
29
+
26 30
 	return nil
27 31
 }

+ 20
- 0
network/v2/sockopts_congestion.go Visa fil

@@ -0,0 +1,20 @@
1
+//go:build linux
2
+
3
+package network
4
+
5
+import (
6
+	"syscall"
7
+
8
+	"golang.org/x/sys/unix"
9
+)
10
+
11
+// setCongestionControl sets BBR as the TCP congestion control algorithm.
12
+// BBR provides better throughput over lossy and high-latency links compared
13
+// to the default cubic, which is especially beneficial for mobile and
14
+// home internet clients. This is best-effort: silently ignored if the
15
+// kernel does not have tcp_bbr available.
16
+func setCongestionControl(conn syscall.RawConn) {
17
+	conn.Control(func(fd uintptr) { //nolint: errcheck
18
+		unix.SetsockoptString(int(fd), unix.IPPROTO_TCP, unix.TCP_CONGESTION, "bbr") //nolint: errcheck
19
+	})
20
+}

+ 7
- 0
network/v2/sockopts_congestion_stub.go Visa fil

@@ -0,0 +1,7 @@
1
+//go:build !linux
2
+
3
+package network
4
+
5
+import "syscall"
6
+
7
+func setCongestionControl(conn syscall.RawConn) {}

+ 20
- 0
network/v2/sockopts_lowat.go Visa fil

@@ -0,0 +1,20 @@
1
+//go:build linux || darwin
2
+
3
+package network
4
+
5
+import (
6
+	"syscall"
7
+
8
+	"golang.org/x/sys/unix"
9
+)
10
+
11
+// setNotSentLowat sets TCP_NOTSENT_LOWAT which limits the amount of
12
+// unsent data queued in the kernel write buffer. Once unsent data drops
13
+// below this threshold the socket becomes writable again, applying
14
+// back-pressure to the relay loop instead of piling up data in kernel
15
+// buffers. This reduces per-connection memory and bufferbloat.
16
+func setNotSentLowat(conn syscall.RawConn) {
17
+	conn.Control(func(fd uintptr) { //nolint: errcheck
18
+		unix.SetsockoptInt(int(fd), unix.IPPROTO_TCP, unix.TCP_NOTSENT_LOWAT, tcpNotSentLowat) //nolint: errcheck
19
+	})
20
+}

+ 7
- 0
network/v2/sockopts_lowat_stub.go Visa fil

@@ -0,0 +1,7 @@
1
+//go:build !linux && !darwin
2
+
3
+package network
4
+
5
+import "syscall"
6
+
7
+func setNotSentLowat(conn syscall.RawConn) {}

network/v2/sockopts_unix.go → network/v2/sockopts_reuseaddr.go Visa fil

@@ -1,5 +1,4 @@
1 1
 //go:build !windows
2
-// +build !windows
3 2
 
4 3
 package network
5 4
 

network/v2/sockopts_windows.go → network/v2/sockopts_reuseaddr_stub.go Visa fil

@@ -1,5 +1,4 @@
1 1
 //go:build windows
2
-// +build windows
3 2
 
4 3
 package network
5 4
 

+ 92
- 0
network/v2/sockopts_test.go Visa fil

@@ -0,0 +1,92 @@
1
+//go:build linux || darwin
2
+// +build linux darwin
3
+
4
+package network
5
+
6
+import (
7
+	"net"
8
+	"runtime"
9
+	"syscall"
10
+	"testing"
11
+	"time"
12
+
13
+	"github.com/stretchr/testify/require"
14
+	"golang.org/x/sys/unix"
15
+)
16
+
17
+func tcpKeepIdleOption() int {
18
+	if runtime.GOOS == "darwin" {
19
+		return 0x10 // TCP_KEEPALIVE on macOS
20
+	}
21
+
22
+	return 0x4 // TCP_KEEPIDLE on Linux
23
+}
24
+
25
+func TestSetCommonSocketOptionsKeepAlive(t *testing.T) {
26
+	t.Parallel()
27
+
28
+	listener, err := net.Listen("tcp", "127.0.0.1:0")
29
+	require.NoError(t, err)
30
+	defer func() {
31
+		err := listener.Close()
32
+		require.NoError(t, err)
33
+	}()
34
+
35
+	type dialResult struct {
36
+		conn net.Conn
37
+		err  error
38
+	}
39
+
40
+	dialDone := make(chan dialResult, 1)
41
+
42
+	go func() {
43
+		c, err := net.Dial("tcp", listener.Addr().String())
44
+		dialDone <- dialResult{conn: c, err: err}
45
+	}()
46
+
47
+	tcpListener, ok := listener.(*net.TCPListener)
48
+	require.True(t, ok, "listener must be a *net.TCPListener")
49
+
50
+	require.NoError(t, tcpListener.SetDeadline(time.Now().Add(5*time.Second)))
51
+
52
+	accepted, err := listener.Accept()
53
+	require.NoError(t, err)
54
+	defer func() {
55
+		err := accepted.Close()
56
+		require.NoError(t, err)
57
+	}()
58
+
59
+	dr := <-dialDone
60
+	require.NoError(t, dr.err)
61
+	defer func() {
62
+		err := dr.conn.Close()
63
+		require.NoError(t, err)
64
+	}()
65
+
66
+	tcpConn := accepted.(*net.TCPConn)
67
+
68
+	err = setCommonSocketOptions(tcpConn, DefaultKeepAliveConfig)
69
+	require.NoError(t, err)
70
+
71
+	rawConn, err := tcpConn.SyscallConn()
72
+	require.NoError(t, err)
73
+
74
+	err = rawConn.Control(func(fd uintptr) {
75
+		val, err := unix.GetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_KEEPALIVE)
76
+		require.NoError(t, err)
77
+		require.NotEqual(t, 0, val, "SO_KEEPALIVE should be enabled")
78
+
79
+		idle, err := unix.GetsockoptInt(int(fd), syscall.IPPROTO_TCP, tcpKeepIdleOption())
80
+		require.NoError(t, err)
81
+		require.Equal(t, 15, idle, "keepalive idle should match DefaultKeepAliveIdle")
82
+
83
+		interval, err := unix.GetsockoptInt(int(fd), syscall.IPPROTO_TCP, unix.TCP_KEEPINTVL)
84
+		require.NoError(t, err)
85
+		require.Equal(t, 15, interval, "keepalive interval should match DefaultKeepAliveInterval")
86
+
87
+		count, err := unix.GetsockoptInt(int(fd), syscall.IPPROTO_TCP, unix.TCP_KEEPCNT)
88
+		require.NoError(t, err)
89
+		require.Equal(t, 9, count, "keepalive count should match DefaultKeepAliveCount")
90
+	})
91
+	require.NoError(t, err)
92
+}

+ 48
- 0
network/v2/sockopts_usertimeout.go Visa fil

@@ -0,0 +1,48 @@
1
+//go:build linux
2
+
3
+package network
4
+
5
+import (
6
+	"net"
7
+	"syscall"
8
+	"time"
9
+
10
+	"golang.org/x/sys/unix"
11
+)
12
+
13
+// Go runtime defaults for KeepAliveConfig when fields are zero.
14
+const (
15
+	goDefaultKeepAliveIdle     = 15 * time.Second
16
+	goDefaultKeepAliveInterval = 15 * time.Second
17
+	goDefaultKeepAliveCount    = 9
18
+)
19
+
20
+// setTCPUserTimeout sets TCP_USER_TIMEOUT on a socket. If transmitted
21
+// data remains unacknowledged for this long, the kernel closes the
22
+// connection. As recommended by Cloudflare
23
+// (https://blog.cloudflare.com/when-tcp-sockets-refuse-to-die/),
24
+// the value is computed as: keepidle + keepintvl * keepcnt. This
25
+// ensures TCP_USER_TIMEOUT and keepalives agree on when to give up.
26
+// Best-effort: silently ignored if unsupported.
27
+func setTCPUserTimeout(conn syscall.RawConn, cfg net.KeepAliveConfig) {
28
+	idle := cfg.Idle
29
+	if idle == 0 {
30
+		idle = goDefaultKeepAliveIdle
31
+	}
32
+
33
+	interval := cfg.Interval
34
+	if interval == 0 {
35
+		interval = goDefaultKeepAliveInterval
36
+	}
37
+
38
+	count := cfg.Count
39
+	if count == 0 {
40
+		count = goDefaultKeepAliveCount
41
+	}
42
+
43
+	timeout := idle + interval*time.Duration(count)
44
+
45
+	conn.Control(func(fd uintptr) { //nolint: errcheck
46
+		unix.SetsockoptInt(int(fd), unix.IPPROTO_TCP, unix.TCP_USER_TIMEOUT, int(timeout.Milliseconds())) //nolint: errcheck
47
+	})
48
+}

+ 10
- 0
network/v2/sockopts_usertimeout_stub.go Visa fil

@@ -0,0 +1,10 @@
1
+//go:build !linux
2
+
3
+package network
4
+
5
+import (
6
+	"net"
7
+	"syscall"
8
+)
9
+
10
+func setTCPUserTimeout(conn syscall.RawConn, cfg net.KeepAliveConfig) {}

+ 1
- 1
network/v2/socks_proxy_test.go Visa fil

@@ -66,7 +66,7 @@ func (suite *SocksProxyTestSuite) SetupSuite() {
66 66
 	require.NoError(suite.T(), err)
67 67
 	suite.authURL = parsed
68 68
 
69
-	suite.baseNetwork = network.New(nil, "mtg", 0, 0, 0)
69
+	suite.baseNetwork = network.New(nil, "mtg", 0, 0, 0, network.DefaultKeepAliveConfig)
70 70
 }
71 71
 
72 72
 func (suite *SocksProxyTestSuite) TestIncorrectSchema() {

Laddar…
Avbryt
Spara