| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146 |
- //go:build linux || darwin
- // +build linux darwin
-
- package network
-
- import (
- "net"
- "runtime"
- "syscall"
- "testing"
- "time"
-
- "github.com/stretchr/testify/require"
- "golang.org/x/sys/unix"
- )
-
// tcpKeepIdleOption returns the TCP-level socket option number used to read
// the keepalive idle time on the running OS. The numeric values are spelled
// out instead of using named constants so the same file builds on both
// platforms covered by its build tags.
func tcpKeepIdleOption() int {
	switch runtime.GOOS {
	case "darwin":
		return 0x10 // TCP_KEEPALIVE on macOS
	default:
		return 0x4 // TCP_KEEPIDLE on Linux
	}
}
-
- func TestSetCommonSocketOptionsKeepAlive(t *testing.T) {
- t.Parallel()
-
- listener, err := net.Listen("tcp", "127.0.0.1:0")
- require.NoError(t, err)
- defer func() {
- err := listener.Close()
- require.NoError(t, err)
- }()
-
- type dialResult struct {
- conn net.Conn
- err error
- }
-
- dialDone := make(chan dialResult, 1)
-
- go func() {
- c, err := net.Dial("tcp", listener.Addr().String())
- dialDone <- dialResult{conn: c, err: err}
- }()
-
- tcpListener, ok := listener.(*net.TCPListener)
- require.True(t, ok, "listener must be a *net.TCPListener")
-
- require.NoError(t, tcpListener.SetDeadline(time.Now().Add(5*time.Second)))
-
- accepted, err := listener.Accept()
- require.NoError(t, err)
- defer func() {
- err := accepted.Close()
- require.NoError(t, err)
- }()
-
- dr := <-dialDone
- require.NoError(t, dr.err)
- defer func() {
- err := dr.conn.Close()
- require.NoError(t, err)
- }()
-
- tcpConn := accepted.(*net.TCPConn)
-
- err = setCommonSocketOptions(tcpConn, DefaultKeepAliveConfig, DefaultTCPNotSentLowat)
- require.NoError(t, err)
-
- rawConn, err := tcpConn.SyscallConn()
- require.NoError(t, err)
-
- err = rawConn.Control(func(fd uintptr) {
- val, err := unix.GetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_KEEPALIVE)
- require.NoError(t, err)
- require.NotEqual(t, 0, val, "SO_KEEPALIVE should be enabled")
-
- idle, err := unix.GetsockoptInt(int(fd), syscall.IPPROTO_TCP, tcpKeepIdleOption())
- require.NoError(t, err)
- require.Equal(t, 15, idle, "keepalive idle should match DefaultKeepAliveIdle")
-
- interval, err := unix.GetsockoptInt(int(fd), syscall.IPPROTO_TCP, unix.TCP_KEEPINTVL)
- require.NoError(t, err)
- require.Equal(t, 15, interval, "keepalive interval should match DefaultKeepAliveInterval")
-
- count, err := unix.GetsockoptInt(int(fd), syscall.IPPROTO_TCP, unix.TCP_KEEPCNT)
- require.NoError(t, err)
- require.Equal(t, 9, count, "keepalive count should match DefaultKeepAliveCount")
- })
- require.NoError(t, err)
- }
-
- func TestSetCommonSocketOptionsNotSentLowat(t *testing.T) {
- t.Parallel()
-
- cases := []struct {
- name string
- want int
- }{
- {name: "default", want: DefaultTCPNotSentLowat},
- {name: "custom", want: 4 * 1024 * 1024},
- }
-
- for _, tc := range cases {
- t.Run(tc.name, func(t *testing.T) {
- t.Parallel()
-
- listener, err := net.Listen("tcp", "127.0.0.1:0")
- require.NoError(t, err)
- defer listener.Close() //nolint: errcheck
-
- dialDone := make(chan struct{})
-
- go func() {
- c, err := net.Dial("tcp", listener.Addr().String())
- if err == nil {
- defer c.Close() //nolint: errcheck
- }
- close(dialDone)
- }()
-
- require.NoError(t, listener.(*net.TCPListener).SetDeadline(time.Now().Add(5*time.Second)))
-
- accepted, err := listener.Accept()
- require.NoError(t, err)
- defer accepted.Close() //nolint: errcheck
-
- <-dialDone
-
- tcpConn := accepted.(*net.TCPConn)
-
- require.NoError(t, setCommonSocketOptions(tcpConn, DefaultKeepAliveConfig, tc.want))
-
- rawConn, err := tcpConn.SyscallConn()
- require.NoError(t, err)
-
- err = rawConn.Control(func(fd uintptr) {
- got, err := unix.GetsockoptInt(int(fd), unix.IPPROTO_TCP, unix.TCP_NOTSENT_LOWAT)
- require.NoError(t, err)
- require.Equal(t, tc.want, got, "TCP_NOTSENT_LOWAT should match value passed to setCommonSocketOptions")
- })
- require.NoError(t, err)
- })
- }
- }
|