| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293 |
- //go:build linux || darwin
- // +build linux darwin
-
- package network_test
-
- import (
- "net"
- "runtime"
- "syscall"
- "testing"
- "time"
-
- "github.com/dolonet/mtg-multi/network"
- "github.com/stretchr/testify/require"
- "golang.org/x/sys/unix"
- )
-
// tcpKeepIdleOption returns the IPPROTO_TCP-level socket option number that
// carries the keep-alive idle time on the running OS: TCP_KEEPALIVE (0x10) on
// macOS and TCP_KEEPIDLE (0x4) otherwise, i.e. on Linux given this file's
// build constraint. The values are spelled out as literals, presumably so one
// source file can serve both platforms even though the option name differs.
func tcpKeepIdleOption() int {
	const (
		darwinTCPKeepAlive = 0x10 // TCP_KEEPALIVE on macOS
		linuxTCPKeepIdle   = 0x4  // TCP_KEEPIDLE on Linux
	)

	switch runtime.GOOS {
	case "darwin":
		return darwinTCPKeepAlive
	default:
		return linuxTCPKeepIdle
	}
}
-
- func TestSetClientSocketOptionsKeepAlive(t *testing.T) {
- t.Parallel()
-
- listener, err := net.Listen("tcp", "127.0.0.1:0")
- require.NoError(t, err)
- defer func() {
- err := listener.Close()
- require.NoError(t, err)
- }()
-
- type dialResult struct {
- conn net.Conn
- err error
- }
-
- dialDone := make(chan dialResult, 1)
-
- go func() {
- c, err := net.Dial("tcp", listener.Addr().String())
- dialDone <- dialResult{conn: c, err: err}
- }()
-
- tcpListener, ok := listener.(*net.TCPListener)
- require.True(t, ok, "listener must be a *net.TCPListener")
-
- require.NoError(t, tcpListener.SetDeadline(time.Now().Add(5*time.Second)))
-
- accepted, err := listener.Accept()
- require.NoError(t, err)
- defer func() {
- err := accepted.Close()
- require.NoError(t, err)
- }()
-
- dr := <-dialDone
- require.NoError(t, dr.err)
- defer func() {
- err := dr.conn.Close()
- require.NoError(t, err)
- }()
-
- err = network.SetClientSocketOptions(accepted, 0)
- require.NoError(t, err)
-
- tcpConn := accepted.(*net.TCPConn)
-
- rawConn, err := tcpConn.SyscallConn()
- require.NoError(t, err)
-
- err = rawConn.Control(func(fd uintptr) {
- val, err := unix.GetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_KEEPALIVE)
- require.NoError(t, err)
- require.NotEqual(t, 0, val, "SO_KEEPALIVE should be enabled")
-
- idle, err := unix.GetsockoptInt(int(fd), syscall.IPPROTO_TCP, tcpKeepIdleOption())
- require.NoError(t, err)
- require.Equal(t, int(network.DefaultKeepAliveIdle.Seconds()), idle)
-
- interval, err := unix.GetsockoptInt(int(fd), syscall.IPPROTO_TCP, unix.TCP_KEEPINTVL)
- require.NoError(t, err)
- require.Equal(t, int(network.DefaultKeepAliveInterval.Seconds()), interval)
-
- count, err := unix.GetsockoptInt(int(fd), syscall.IPPROTO_TCP, unix.TCP_KEEPCNT)
- require.NoError(t, err)
- require.Equal(t, network.DefaultKeepAliveCount, count)
- })
- require.NoError(t, err)
- }
|