Highly-opinionated (ex-bullshit-free) MTPROTO proxy for Telegram. If you use v1.0 or an upgrade broke your proxy, please read the chapter "Version 2".
Nelze vybrat více než 25 témat Téma musí začínat písmenem nebo číslem, může obsahovat pomlčky („-“) a může být dlouhé až 35 znaků.

conns_internal_test.go 7.3KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307
  1. package mtglib
  2. import (
  3. "bufio"
  4. "bytes"
  5. "context"
  6. "errors"
  7. "io"
  8. "net"
  9. "testing"
  10. "time"
  11. "github.com/9seconds/mtg/v2/internal/testlib"
  12. "github.com/pires/go-proxyproto"
  13. "github.com/stretchr/testify/mock"
  14. "github.com/stretchr/testify/suite"
  15. )
  16. type ConnRewindBaseConn struct {
  17. testlib.EssentialsConnMock
  18. readBuffer bytes.Buffer
  19. }
  20. func (c *ConnRewindBaseConn) Read(p []byte) (int, error) {
  21. c.Called(p)
  22. return c.readBuffer.Read(p) //nolint: wrapcheck
  23. }
  24. type ConnTrafficTestSuite struct {
  25. suite.Suite
  26. eventStreamMock *EventStreamMock
  27. connMock *testlib.EssentialsConnMock
  28. conn io.ReadWriter
  29. }
  30. func (suite *ConnTrafficTestSuite) SetupTest() {
  31. suite.eventStreamMock = &EventStreamMock{}
  32. suite.connMock = &testlib.EssentialsConnMock{}
  33. suite.conn = connTraffic{
  34. Conn: suite.connMock,
  35. streamID: "CONNID",
  36. ctx: context.Background(),
  37. stream: suite.eventStreamMock,
  38. }
  39. }
  40. func (suite *ConnTrafficTestSuite) TearDownTest() {
  41. suite.eventStreamMock.AssertExpectations(suite.T())
  42. suite.connMock.AssertExpectations(suite.T())
  43. }
  44. func (suite *ConnTrafficTestSuite) TestReadOk() {
  45. suite.eventStreamMock.
  46. On("Send", mock.Anything, mock.Anything).
  47. Once().
  48. Run(func(args mock.Arguments) {
  49. evt, ok := args.Get(1).(EventTraffic)
  50. suite.True(ok)
  51. suite.Equal("CONNID", evt.StreamID())
  52. suite.WithinDuration(time.Now(), evt.Timestamp(), time.Second)
  53. suite.EqualValues(10, evt.Traffic)
  54. suite.True(evt.IsRead)
  55. })
  56. suite.connMock.On("Read", mock.Anything).Once().Return(10, nil)
  57. n, err := suite.conn.Read(make([]byte, 10))
  58. suite.NoError(err)
  59. suite.Equal(10, n)
  60. }
  61. func (suite *ConnTrafficTestSuite) TestReadErr() { //nolint: dupl
  62. suite.eventStreamMock.
  63. On("Send", mock.Anything, mock.Anything).
  64. Once().
  65. Run(func(args mock.Arguments) {
  66. evt, ok := args.Get(1).(EventTraffic)
  67. suite.True(ok)
  68. suite.Equal("CONNID", evt.StreamID())
  69. suite.WithinDuration(time.Now(), evt.Timestamp(), time.Second)
  70. suite.EqualValues(10, evt.Traffic)
  71. suite.True(evt.IsRead)
  72. })
  73. suite.connMock.On("Read", mock.Anything).Once().Return(10, io.EOF)
  74. n, err := suite.conn.Read(make([]byte, 10))
  75. suite.True(errors.Is(err, io.EOF))
  76. suite.Equal(10, n)
  77. }
  78. func (suite *ConnTrafficTestSuite) TestReadNothingOk() {
  79. suite.connMock.On("Read", mock.Anything).Once().Return(0, nil)
  80. n, err := suite.conn.Read(make([]byte, 10))
  81. suite.NoError(err)
  82. suite.Equal(0, n)
  83. }
  84. func (suite *ConnTrafficTestSuite) TestReadNothingErr() {
  85. suite.connMock.On("Read", mock.Anything).Once().Return(0, io.EOF)
  86. n, err := suite.conn.Read(make([]byte, 10))
  87. suite.True(errors.Is(err, io.EOF))
  88. suite.Equal(0, n)
  89. }
  90. func (suite *ConnTrafficTestSuite) TestWriteOk() {
  91. suite.eventStreamMock.
  92. On("Send", mock.Anything, mock.Anything).
  93. Once().
  94. Run(func(args mock.Arguments) {
  95. evt, ok := args.Get(1).(EventTraffic)
  96. suite.True(ok)
  97. suite.Equal("CONNID", evt.StreamID())
  98. suite.WithinDuration(time.Now(), evt.Timestamp(), time.Second)
  99. suite.EqualValues(10, evt.Traffic)
  100. suite.False(evt.IsRead)
  101. })
  102. suite.connMock.On("Write", mock.Anything).Once().Return(10, nil)
  103. n, err := suite.conn.Write(make([]byte, 10))
  104. suite.NoError(err)
  105. suite.Equal(10, n)
  106. }
  107. func (suite *ConnTrafficTestSuite) TestWriteErr() { //nolint: dupl
  108. suite.eventStreamMock.
  109. On("Send", mock.Anything, mock.Anything).
  110. Once().
  111. Run(func(args mock.Arguments) {
  112. evt, ok := args.Get(1).(EventTraffic)
  113. suite.True(ok)
  114. suite.Equal("CONNID", evt.StreamID())
  115. suite.WithinDuration(time.Now(), evt.Timestamp(), time.Second)
  116. suite.EqualValues(10, evt.Traffic)
  117. suite.False(evt.IsRead)
  118. })
  119. suite.connMock.On("Write", mock.Anything).Once().Return(10, io.EOF)
  120. n, err := suite.conn.Write(make([]byte, 10))
  121. suite.True(errors.Is(err, io.EOF))
  122. suite.Equal(10, n)
  123. }
  124. func (suite *ConnTrafficTestSuite) TestWriteNothingOk() {
  125. suite.connMock.On("Write", mock.Anything).Once().Return(0, nil)
  126. n, err := suite.conn.Write(make([]byte, 10))
  127. suite.NoError(err)
  128. suite.Equal(0, n)
  129. }
  130. func (suite *ConnTrafficTestSuite) TestWriteNothingErr() {
  131. suite.connMock.On("Write", mock.Anything).Once().Return(0, io.EOF)
  132. n, err := suite.conn.Write(make([]byte, 10))
  133. suite.True(errors.Is(err, io.EOF))
  134. suite.Equal(0, n)
  135. }
  136. type ConnRewindTestSuite struct {
  137. suite.Suite
  138. connMock *ConnRewindBaseConn
  139. conn *connRewind
  140. }
  141. func (suite *ConnRewindTestSuite) SetupTest() {
  142. suite.connMock = &ConnRewindBaseConn{}
  143. suite.conn = newConnRewind(suite.connMock)
  144. }
  145. func (suite *ConnRewindTestSuite) TearDownTest() {
  146. suite.connMock.AssertExpectations(suite.T())
  147. }
  148. func (suite *ConnRewindTestSuite) TestRead() {
  149. suite.connMock.On("Read", mock.Anything)
  150. suite.connMock.readBuffer.Write([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})
  151. buf := make([]byte, 2)
  152. n, err := suite.conn.Read(buf)
  153. suite.NoError(err)
  154. suite.Equal(2, n)
  155. suite.Equal([]byte{1, 2}, buf)
  156. n, err = suite.conn.Read(buf)
  157. suite.NoError(err)
  158. suite.Equal(2, n)
  159. suite.Equal([]byte{3, 4}, buf)
  160. suite.conn.Rewind()
  161. data, err := io.ReadAll(suite.conn)
  162. suite.NoError(err)
  163. suite.Equal([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, data)
  164. }
  165. type ConnProxyProtocolTestSuite struct {
  166. suite.Suite
  167. sourceConnMock *testlib.EssentialsConnMock
  168. targetConnMock *testlib.EssentialsConnMock
  169. conn *connProxyProtocol
  170. }
  171. func (suite *ConnProxyProtocolTestSuite) SetupTest() {
  172. suite.sourceConnMock = &testlib.EssentialsConnMock{}
  173. suite.targetConnMock = &testlib.EssentialsConnMock{}
  174. localAddr := &net.TCPAddr{
  175. IP: net.ParseIP("127.0.0.1").To4(),
  176. }
  177. remoteAddr := &net.TCPAddr{
  178. IP: net.ParseIP("127.0.0.2").To4(),
  179. }
  180. suite.sourceConnMock.
  181. On("RemoteAddr").
  182. Return(localAddr)
  183. suite.targetConnMock.
  184. On("RemoteAddr").
  185. Maybe().
  186. Return(remoteAddr)
  187. suite.conn = newConnProxyProtocol(suite.sourceConnMock, suite.targetConnMock)
  188. }
  189. func (suite *ConnProxyProtocolTestSuite) TestRead() {
  190. value := []byte{1, 2, 3, 4, 5}
  191. toRead := make([]byte, len(value))
  192. suite.targetConnMock.
  193. On("Read", mock.AnythingOfType("[]uint8")).
  194. Once().
  195. Return(len(toRead), nil).
  196. Run(func(args mock.Arguments) {
  197. arr := args.Get(0).([]byte)
  198. copy(arr, value)
  199. })
  200. n, err := suite.conn.Read(toRead)
  201. suite.Equal(len(value), n)
  202. suite.NoError(err)
  203. suite.Equal(value, toRead)
  204. }
  205. func (suite *ConnProxyProtocolTestSuite) TestWrite() {
  206. value := []byte{1, 2, 3, 4, 5}
  207. buf := &bytes.Buffer{}
  208. bufReader := bufio.NewReader(buf)
  209. suite.targetConnMock.
  210. On("Write", mock.AnythingOfType("[]uint8")).
  211. Return(28, nil).
  212. Run(func(args mock.Arguments) {
  213. arr := args.Get(0).([]byte)
  214. buf.Write(arr)
  215. })
  216. _, err := suite.conn.Write(value)
  217. suite.NoError(err)
  218. header, err := proxyproto.Read(bufReader)
  219. suite.NoError(err)
  220. sourceAddr, destAddr, ok := header.TCPAddrs()
  221. suite.True(ok)
  222. suite.Equal(suite.sourceConnMock.RemoteAddr(), sourceAddr)
  223. suite.Equal(suite.targetConnMock.RemoteAddr(), destAddr)
  224. read, _ := io.ReadAll(bufReader)
  225. suite.Equal(value, read)
  226. _, err = suite.conn.Write(value)
  227. suite.NoError(err)
  228. read, _ = io.ReadAll(bufReader)
  229. suite.Equal(value, read)
  230. }
  231. func (suite *ConnProxyProtocolTestSuite) TearDownTest() {
  232. suite.sourceConnMock.AssertExpectations(suite.T())
  233. suite.targetConnMock.AssertExpectations(suite.T())
  234. }
  235. func TestConnTraffic(t *testing.T) {
  236. t.Parallel()
  237. suite.Run(t, &ConnTrafficTestSuite{})
  238. }
  239. func TestConnRewind(t *testing.T) {
  240. t.Parallel()
  241. suite.Run(t, &ConnRewindTestSuite{})
  242. }
  243. func TestConnProxyProtocol(t *testing.T) {
  244. t.Parallel()
  245. suite.Run(t, &ConnProxyProtocolTestSuite{})
  246. }