| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307 |
- package mtglib
-
- import (
- "bufio"
- "bytes"
- "context"
- "errors"
- "io"
- "net"
- "testing"
- "time"
-
- "github.com/9seconds/mtg/v2/internal/testlib"
- "github.com/pires/go-proxyproto"
- "github.com/stretchr/testify/mock"
- "github.com/stretchr/testify/suite"
- )
-
- type ConnRewindBaseConn struct {
- testlib.EssentialsConnMock
-
- readBuffer bytes.Buffer
- }
-
- func (c *ConnRewindBaseConn) Read(p []byte) (int, error) {
- c.Called(p)
-
- return c.readBuffer.Read(p) //nolint: wrapcheck
- }
-
- type ConnTrafficTestSuite struct {
- suite.Suite
-
- eventStreamMock *EventStreamMock
- connMock *testlib.EssentialsConnMock
- conn io.ReadWriter
- }
-
- func (suite *ConnTrafficTestSuite) SetupTest() {
- suite.eventStreamMock = &EventStreamMock{}
- suite.connMock = &testlib.EssentialsConnMock{}
- suite.conn = connTraffic{
- Conn: suite.connMock,
- streamID: "CONNID",
- ctx: context.Background(),
- stream: suite.eventStreamMock,
- }
- }
-
- func (suite *ConnTrafficTestSuite) TearDownTest() {
- suite.eventStreamMock.AssertExpectations(suite.T())
- suite.connMock.AssertExpectations(suite.T())
- }
-
- func (suite *ConnTrafficTestSuite) TestReadOk() {
- suite.eventStreamMock.
- On("Send", mock.Anything, mock.Anything).
- Once().
- Run(func(args mock.Arguments) {
- evt, ok := args.Get(1).(EventTraffic)
-
- suite.True(ok)
- suite.Equal("CONNID", evt.StreamID())
- suite.WithinDuration(time.Now(), evt.Timestamp(), time.Second)
- suite.EqualValues(10, evt.Traffic)
- suite.True(evt.IsRead)
- })
- suite.connMock.On("Read", mock.Anything).Once().Return(10, nil)
-
- n, err := suite.conn.Read(make([]byte, 10))
- suite.NoError(err)
- suite.Equal(10, n)
- }
-
- func (suite *ConnTrafficTestSuite) TestReadErr() { //nolint: dupl
- suite.eventStreamMock.
- On("Send", mock.Anything, mock.Anything).
- Once().
- Run(func(args mock.Arguments) {
- evt, ok := args.Get(1).(EventTraffic)
-
- suite.True(ok)
- suite.Equal("CONNID", evt.StreamID())
- suite.WithinDuration(time.Now(), evt.Timestamp(), time.Second)
- suite.EqualValues(10, evt.Traffic)
- suite.True(evt.IsRead)
- })
- suite.connMock.On("Read", mock.Anything).Once().Return(10, io.EOF)
-
- n, err := suite.conn.Read(make([]byte, 10))
- suite.True(errors.Is(err, io.EOF))
- suite.Equal(10, n)
- }
-
- func (suite *ConnTrafficTestSuite) TestReadNothingOk() {
- suite.connMock.On("Read", mock.Anything).Once().Return(0, nil)
-
- n, err := suite.conn.Read(make([]byte, 10))
- suite.NoError(err)
- suite.Equal(0, n)
- }
-
- func (suite *ConnTrafficTestSuite) TestReadNothingErr() {
- suite.connMock.On("Read", mock.Anything).Once().Return(0, io.EOF)
-
- n, err := suite.conn.Read(make([]byte, 10))
- suite.True(errors.Is(err, io.EOF))
- suite.Equal(0, n)
- }
-
- func (suite *ConnTrafficTestSuite) TestWriteOk() {
- suite.eventStreamMock.
- On("Send", mock.Anything, mock.Anything).
- Once().
- Run(func(args mock.Arguments) {
- evt, ok := args.Get(1).(EventTraffic)
-
- suite.True(ok)
- suite.Equal("CONNID", evt.StreamID())
- suite.WithinDuration(time.Now(), evt.Timestamp(), time.Second)
- suite.EqualValues(10, evt.Traffic)
- suite.False(evt.IsRead)
- })
- suite.connMock.On("Write", mock.Anything).Once().Return(10, nil)
-
- n, err := suite.conn.Write(make([]byte, 10))
- suite.NoError(err)
- suite.Equal(10, n)
- }
-
- func (suite *ConnTrafficTestSuite) TestWriteErr() { //nolint: dupl
- suite.eventStreamMock.
- On("Send", mock.Anything, mock.Anything).
- Once().
- Run(func(args mock.Arguments) {
- evt, ok := args.Get(1).(EventTraffic)
-
- suite.True(ok)
- suite.Equal("CONNID", evt.StreamID())
- suite.WithinDuration(time.Now(), evt.Timestamp(), time.Second)
- suite.EqualValues(10, evt.Traffic)
- suite.False(evt.IsRead)
- })
- suite.connMock.On("Write", mock.Anything).Once().Return(10, io.EOF)
-
- n, err := suite.conn.Write(make([]byte, 10))
- suite.True(errors.Is(err, io.EOF))
- suite.Equal(10, n)
- }
-
- func (suite *ConnTrafficTestSuite) TestWriteNothingOk() {
- suite.connMock.On("Write", mock.Anything).Once().Return(0, nil)
-
- n, err := suite.conn.Write(make([]byte, 10))
- suite.NoError(err)
- suite.Equal(0, n)
- }
-
- func (suite *ConnTrafficTestSuite) TestWriteNothingErr() {
- suite.connMock.On("Write", mock.Anything).Once().Return(0, io.EOF)
-
- n, err := suite.conn.Write(make([]byte, 10))
- suite.True(errors.Is(err, io.EOF))
- suite.Equal(0, n)
- }
-
- type ConnRewindTestSuite struct {
- suite.Suite
-
- connMock *ConnRewindBaseConn
- conn *connRewind
- }
-
- func (suite *ConnRewindTestSuite) SetupTest() {
- suite.connMock = &ConnRewindBaseConn{}
- suite.conn = newConnRewind(suite.connMock)
- }
-
- func (suite *ConnRewindTestSuite) TearDownTest() {
- suite.connMock.AssertExpectations(suite.T())
- }
-
- func (suite *ConnRewindTestSuite) TestRead() {
- suite.connMock.On("Read", mock.Anything)
- suite.connMock.readBuffer.Write([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})
-
- buf := make([]byte, 2)
-
- n, err := suite.conn.Read(buf)
- suite.NoError(err)
- suite.Equal(2, n)
- suite.Equal([]byte{1, 2}, buf)
-
- n, err = suite.conn.Read(buf)
- suite.NoError(err)
- suite.Equal(2, n)
- suite.Equal([]byte{3, 4}, buf)
-
- suite.conn.Rewind()
-
- data, err := io.ReadAll(suite.conn)
- suite.NoError(err)
- suite.Equal([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, data)
- }
-
- type ConnProxyProtocolTestSuite struct {
- suite.Suite
-
- sourceConnMock *testlib.EssentialsConnMock
- targetConnMock *testlib.EssentialsConnMock
- conn *connProxyProtocol
- }
-
- func (suite *ConnProxyProtocolTestSuite) SetupTest() {
- suite.sourceConnMock = &testlib.EssentialsConnMock{}
- suite.targetConnMock = &testlib.EssentialsConnMock{}
-
- localAddr := &net.TCPAddr{
- IP: net.ParseIP("127.0.0.1").To4(),
- }
- remoteAddr := &net.TCPAddr{
- IP: net.ParseIP("127.0.0.2").To4(),
- }
-
- suite.sourceConnMock.
- On("RemoteAddr").
- Return(localAddr)
- suite.targetConnMock.
- On("RemoteAddr").
- Maybe().
- Return(remoteAddr)
-
- suite.conn = newConnProxyProtocol(suite.sourceConnMock, suite.targetConnMock)
- }
-
- func (suite *ConnProxyProtocolTestSuite) TestRead() {
- value := []byte{1, 2, 3, 4, 5}
- toRead := make([]byte, len(value))
-
- suite.targetConnMock.
- On("Read", mock.AnythingOfType("[]uint8")).
- Once().
- Return(len(toRead), nil).
- Run(func(args mock.Arguments) {
- arr := args.Get(0).([]byte)
- copy(arr, value)
- })
-
- n, err := suite.conn.Read(toRead)
- suite.Equal(len(value), n)
- suite.NoError(err)
- suite.Equal(value, toRead)
- }
-
- func (suite *ConnProxyProtocolTestSuite) TestWrite() {
- value := []byte{1, 2, 3, 4, 5}
- buf := &bytes.Buffer{}
- bufReader := bufio.NewReader(buf)
-
- suite.targetConnMock.
- On("Write", mock.AnythingOfType("[]uint8")).
- Return(28, nil).
- Run(func(args mock.Arguments) {
- arr := args.Get(0).([]byte)
- buf.Write(arr)
- })
-
- _, err := suite.conn.Write(value)
- suite.NoError(err)
-
- header, err := proxyproto.Read(bufReader)
- suite.NoError(err)
-
- sourceAddr, destAddr, ok := header.TCPAddrs()
- suite.True(ok)
- suite.Equal(suite.sourceConnMock.RemoteAddr(), sourceAddr)
- suite.Equal(suite.targetConnMock.RemoteAddr(), destAddr)
-
- read, _ := io.ReadAll(bufReader)
- suite.Equal(value, read)
-
- _, err = suite.conn.Write(value)
- suite.NoError(err)
-
- read, _ = io.ReadAll(bufReader)
- suite.Equal(value, read)
- }
-
- func (suite *ConnProxyProtocolTestSuite) TearDownTest() {
- suite.sourceConnMock.AssertExpectations(suite.T())
- suite.targetConnMock.AssertExpectations(suite.T())
- }
-
- func TestConnTraffic(t *testing.T) {
- t.Parallel()
- suite.Run(t, &ConnTrafficTestSuite{})
- }
-
- func TestConnRewind(t *testing.T) {
- t.Parallel()
- suite.Run(t, &ConnRewindTestSuite{})
- }
-
- func TestConnProxyProtocol(t *testing.T) {
- t.Parallel()
- suite.Run(t, &ConnProxyProtocolTestSuite{})
- }
|