| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293 |
- package doppel
-
- import (
- "bytes"
- "context"
- "encoding/binary"
- "errors"
- "io"
- "sync"
- "testing"
- "time"
-
- "github.com/9seconds/mtg/v2/internal/testlib"
- "github.com/9seconds/mtg/v2/mtglib/internal/tls"
- "github.com/stretchr/testify/mock"
- "github.com/stretchr/testify/suite"
- )
-
// ConnMock is the underlying connection handed to NewConn in these tests.
// It embeds testlib.EssentialsConnMock for the testify-driven methods and
// records every successful Write in writeBuffer, so tests can inspect the
// exact bytes that reached the "wire" via Written.
type ConnMock struct {
	testlib.EssentialsConnMock

	// mu guards writeBuffer: tests poll Written while the Conn under test
	// may be writing asynchronously.
	mu          sync.Mutex
	writeBuffer bytes.Buffer
}
-
- func (m *ConnMock) Write(p []byte) (int, error) {
- args := m.Called(p)
- if err := args.Error(1); err != nil {
- return args.Int(0), err
- }
-
- m.mu.Lock()
- defer m.mu.Unlock()
-
- return m.writeBuffer.Write(p)
- }
-
- func (m *ConnMock) Written() []byte {
- m.mu.Lock()
- defer m.mu.Unlock()
-
- return bytes.Clone(m.writeBuffer.Bytes())
- }
-
// ConnTestSuite exercises the Conn returned by NewConn against a ConnMock.
// SetupTest creates a fresh mock and context for every test.
type ConnTestSuite struct {
	suite.Suite

	// connMock is the underlying connection passed to NewConn.
	connMock *ConnMock
	// ctx is passed to NewConn; ctxCancel is invoked in TearDownTest.
	ctx       context.Context
	ctxCancel context.CancelFunc
}
-
- func (suite *ConnTestSuite) SetupTest() {
- ctx, cancel := context.WithCancel(context.Background())
- suite.ctx = ctx
- suite.ctxCancel = cancel
- suite.connMock = &ConnMock{}
- }
-
- func (suite *ConnTestSuite) TearDownTest() {
- suite.ctxCancel()
- suite.connMock.AssertExpectations(suite.T())
- }
-
- func (suite *ConnTestSuite) makeConn() Conn {
- return NewConn(suite.ctx, suite.connMock, &Stats{
- k: 2.0,
- lambda: 0.01,
- })
- }
-
- func (suite *ConnTestSuite) TestWriteBuffersData() {
- suite.connMock.
- On("Write", mock.AnythingOfType("[]uint8")).
- Return(0, nil).
- Maybe()
-
- c := suite.makeConn()
- defer c.Stop()
-
- n, err := c.Write([]byte{1, 2, 3})
- suite.NoError(err)
- suite.Equal(3, n)
- }
-
- func (suite *ConnTestSuite) TestWriteOutputsTLSRecords() {
- suite.connMock.
- On("Write", mock.AnythingOfType("[]uint8")).
- Return(0, nil).
- Maybe()
-
- c := suite.makeConn()
-
- payload := []byte("hello doppelganger")
- _, err := c.Write(payload)
- suite.NoError(err)
-
- suite.Eventually(func() bool {
- return len(suite.connMock.Written()) > 0
- }, 2*time.Second, time.Millisecond)
-
- c.Stop()
-
- assembled := &bytes.Buffer{}
- reader := bytes.NewReader(suite.connMock.Written())
-
- for {
- header := make([]byte, tls.SizeHeader)
- if _, err := io.ReadFull(reader, header); err != nil {
- break
- }
-
- suite.Equal(byte(tls.TypeApplicationData), header[0])
- suite.Equal(tls.TLSVersion[:], header[tls.SizeRecordType:tls.SizeRecordType+tls.SizeVersion])
-
- length := binary.BigEndian.Uint16(header[tls.SizeRecordType+tls.SizeVersion:])
- suite.Greater(length, uint16(0))
-
- rec := make([]byte, length)
- _, err := io.ReadFull(reader, rec)
- suite.NoError(err)
-
- assembled.Write(rec)
- }
-
- suite.Equal(payload, assembled.Bytes())
- }
-
- func (suite *ConnTestSuite) TestWriteReturnsErrorAfterStop() {
- suite.connMock.
- On("Write", mock.AnythingOfType("[]uint8")).
- Return(0, nil).
- Maybe()
-
- c := suite.makeConn()
- c.Stop()
-
- time.Sleep(10 * time.Millisecond)
-
- _, err := c.Write([]byte{1})
- suite.Error(err)
- }
-
- func (suite *ConnTestSuite) TestStopOnUnderlyingWriteError() {
- suite.connMock.
- On("Write", mock.AnythingOfType("[]uint8")).
- Return(0, errors.New("connection reset")).
- Maybe()
-
- c := suite.makeConn()
-
- _, _ = c.Write([]byte("data"))
-
- suite.Eventually(func() bool {
- _, err := c.Write([]byte{1})
- return err != nil
- }, 2*time.Second, time.Millisecond)
- }
-
- func (suite *ConnTestSuite) TestSyncWriteDataSent() {
- suite.connMock.
- On("Write", mock.AnythingOfType("[]uint8")).
- Return(0, nil).
- Maybe()
-
- c := suite.makeConn()
- defer c.Stop()
-
- payload := []byte("sync hello")
- n, err := c.SyncWrite(payload)
- suite.NoError(err)
- suite.Equal(len(payload), n)
-
- // SyncWrite returns only after data is flushed to the wire.
- assembled := &bytes.Buffer{}
- reader := bytes.NewReader(suite.connMock.Written())
-
- for {
- header := make([]byte, tls.SizeHeader)
- if _, err := io.ReadFull(reader, header); err != nil {
- break
- }
-
- suite.Equal(byte(tls.TypeApplicationData), header[0])
-
- length := binary.BigEndian.Uint16(header[tls.SizeRecordType+tls.SizeVersion:])
- rec := make([]byte, length)
- _, err := io.ReadFull(reader, rec)
- suite.NoError(err)
-
- assembled.Write(rec)
- }
-
- suite.Equal(payload, assembled.Bytes())
- }
-
- func (suite *ConnTestSuite) TestSyncWriteDrainsBufferFirst() {
- suite.connMock.
- On("Write", mock.AnythingOfType("[]uint8")).
- Return(0, nil).
- Maybe()
-
- c := suite.makeConn()
- defer c.Stop()
-
- // Buffer some data via async Write.
- _, err := c.Write([]byte("first"))
- suite.NoError(err)
-
- // SyncWrite must drain "first" before sending "second".
- n, err := c.SyncWrite([]byte("second"))
- suite.NoError(err)
- suite.Equal(6, n)
-
- // All data should be on the wire now.
- assembled := &bytes.Buffer{}
- reader := bytes.NewReader(suite.connMock.Written())
-
- for {
- header := make([]byte, tls.SizeHeader)
- if _, err := io.ReadFull(reader, header); err != nil {
- break
- }
-
- length := binary.BigEndian.Uint16(header[tls.SizeRecordType+tls.SizeVersion:])
- rec := make([]byte, length)
- _, err := io.ReadFull(reader, rec)
- suite.NoError(err)
-
- assembled.Write(rec)
- }
-
- suite.Equal([]byte("firstsecond"), assembled.Bytes())
- }
-
- func (suite *ConnTestSuite) TestSyncWriteBlocksAsyncWrite() {
- suite.connMock.
- On("Write", mock.AnythingOfType("[]uint8")).
- Return(0, nil).
- Maybe()
-
- c := suite.makeConn()
- defer c.Stop()
-
- // Start SyncWrite — it holds exclusive lock.
- syncDone := make(chan struct{})
-
- go func() {
- defer close(syncDone)
- c.SyncWrite([]byte("exclusive")) //nolint: errcheck
- }()
-
- // Give SyncWrite time to acquire the lock.
- time.Sleep(10 * time.Millisecond)
-
- // Async Write should block until SyncWrite completes.
- writeDone := make(chan struct{})
-
- go func() {
- defer close(writeDone)
- c.Write([]byte("blocked")) //nolint: errcheck
- }()
-
- // SyncWrite should finish first.
- <-syncDone
-
- select {
- case <-writeDone:
- // Write completed after SyncWrite — correct.
- case <-time.After(2 * time.Second):
- suite.Fail("async Write did not unblock after SyncWrite completed")
- }
- }
-
- func (suite *ConnTestSuite) TestSyncWriteReturnsErrorAfterStop() {
- suite.connMock.
- On("Write", mock.AnythingOfType("[]uint8")).
- Return(0, nil).
- Maybe()
-
- c := suite.makeConn()
- c.Stop()
-
- time.Sleep(10 * time.Millisecond)
-
- _, err := c.SyncWrite([]byte("too late"))
- suite.Error(err)
- }
-
- func TestConn(t *testing.T) {
- t.Parallel()
- suite.Run(t, &ConnTestSuite{})
- }
|