Highly opinionated (ex-bullshit-free) MTPROTO proxy for Telegram. If you use v1.0, or an upgrade broke your proxy, please read the chapter "Version 2".
Du kan inte välja fler än 25 ämnen Ämnen måste starta med en bokstav eller siffra, kan innehålla bindestreck ('-') och vara max 35 tecken långa.

conns_internal_test.go 11KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458
  1. package mtglib
  2. import (
  3. "bufio"
  4. "bytes"
  5. "context"
  6. "errors"
  7. "io"
  8. "net"
  9. "testing"
  10. "time"
  11. "github.com/9seconds/mtg/v2/internal/testlib"
  12. "github.com/pires/go-proxyproto"
  13. "github.com/stretchr/testify/mock"
  14. "github.com/stretchr/testify/suite"
  15. )
  // netTimeoutError is a minimal net.Error whose Timeout method reports true.
  // Tests return it from mocked Read calls to simulate an expired deadline.
  type netTimeoutError struct{}

  // Compile-time check: if netTimeoutError stops satisfying net.Error, the
  // build breaks instead of the idle-timeout tests silently testing nothing.
  var _ net.Error = netTimeoutError{}

  // Error returns a fixed "i/o timeout" message.
  func (e netTimeoutError) Error() string { return "i/o timeout" }

  // Timeout always reports true: this value only ever models a timeout.
  func (e netTimeoutError) Timeout() bool { return true }

  // Temporary is deprecated in package net but is still part of the net.Error
  // interface, so it has to be implemented.
  func (e netTimeoutError) Temporary() bool { return true }
  20. type ConnRewindBaseConn struct {
  21. testlib.EssentialsConnMock
  22. readBuffer bytes.Buffer
  23. }
  24. func (c *ConnRewindBaseConn) Read(p []byte) (int, error) {
  25. c.Called(p)
  26. return c.readBuffer.Read(p) //nolint: wrapcheck
  27. }
  28. type ConnTrafficTestSuite struct {
  29. suite.Suite
  30. eventStreamMock *EventStreamMock
  31. connMock *testlib.EssentialsConnMock
  32. conn io.ReadWriter
  33. }
  34. func (suite *ConnTrafficTestSuite) SetupTest() {
  35. suite.eventStreamMock = &EventStreamMock{}
  36. suite.connMock = &testlib.EssentialsConnMock{}
  37. suite.conn = connTraffic{
  38. Conn: suite.connMock,
  39. streamID: "CONNID",
  40. ctx: context.Background(),
  41. stream: suite.eventStreamMock,
  42. }
  43. }
  44. func (suite *ConnTrafficTestSuite) TearDownTest() {
  45. suite.eventStreamMock.AssertExpectations(suite.T())
  46. suite.connMock.AssertExpectations(suite.T())
  47. }
  48. func (suite *ConnTrafficTestSuite) TestReadOk() {
  49. suite.eventStreamMock.
  50. On("Send", mock.Anything, mock.Anything).
  51. Once().
  52. Run(func(args mock.Arguments) {
  53. evt, ok := args.Get(1).(EventTraffic)
  54. suite.True(ok)
  55. suite.Equal("CONNID", evt.StreamID())
  56. suite.WithinDuration(time.Now(), evt.Timestamp(), time.Second)
  57. suite.EqualValues(10, evt.Traffic)
  58. suite.True(evt.IsRead)
  59. })
  60. suite.connMock.On("Read", mock.Anything).Once().Return(10, nil)
  61. n, err := suite.conn.Read(make([]byte, 10))
  62. suite.NoError(err)
  63. suite.Equal(10, n)
  64. }
  65. func (suite *ConnTrafficTestSuite) TestReadErr() { //nolint: dupl
  66. suite.eventStreamMock.
  67. On("Send", mock.Anything, mock.Anything).
  68. Once().
  69. Run(func(args mock.Arguments) {
  70. evt, ok := args.Get(1).(EventTraffic)
  71. suite.True(ok)
  72. suite.Equal("CONNID", evt.StreamID())
  73. suite.WithinDuration(time.Now(), evt.Timestamp(), time.Second)
  74. suite.EqualValues(10, evt.Traffic)
  75. suite.True(evt.IsRead)
  76. })
  77. suite.connMock.On("Read", mock.Anything).Once().Return(10, io.EOF)
  78. n, err := suite.conn.Read(make([]byte, 10))
  79. suite.True(errors.Is(err, io.EOF))
  80. suite.Equal(10, n)
  81. }
  82. func (suite *ConnTrafficTestSuite) TestReadNothingOk() {
  83. suite.connMock.On("Read", mock.Anything).Once().Return(0, nil)
  84. n, err := suite.conn.Read(make([]byte, 10))
  85. suite.NoError(err)
  86. suite.Equal(0, n)
  87. }
  88. func (suite *ConnTrafficTestSuite) TestReadNothingErr() {
  89. suite.connMock.On("Read", mock.Anything).Once().Return(0, io.EOF)
  90. n, err := suite.conn.Read(make([]byte, 10))
  91. suite.True(errors.Is(err, io.EOF))
  92. suite.Equal(0, n)
  93. }
  94. func (suite *ConnTrafficTestSuite) TestWriteOk() {
  95. suite.eventStreamMock.
  96. On("Send", mock.Anything, mock.Anything).
  97. Once().
  98. Run(func(args mock.Arguments) {
  99. evt, ok := args.Get(1).(EventTraffic)
  100. suite.True(ok)
  101. suite.Equal("CONNID", evt.StreamID())
  102. suite.WithinDuration(time.Now(), evt.Timestamp(), time.Second)
  103. suite.EqualValues(10, evt.Traffic)
  104. suite.False(evt.IsRead)
  105. })
  106. suite.connMock.On("Write", mock.Anything).Once().Return(10, nil)
  107. n, err := suite.conn.Write(make([]byte, 10))
  108. suite.NoError(err)
  109. suite.Equal(10, n)
  110. }
  111. func (suite *ConnTrafficTestSuite) TestWriteErr() { //nolint: dupl
  112. suite.eventStreamMock.
  113. On("Send", mock.Anything, mock.Anything).
  114. Once().
  115. Run(func(args mock.Arguments) {
  116. evt, ok := args.Get(1).(EventTraffic)
  117. suite.True(ok)
  118. suite.Equal("CONNID", evt.StreamID())
  119. suite.WithinDuration(time.Now(), evt.Timestamp(), time.Second)
  120. suite.EqualValues(10, evt.Traffic)
  121. suite.False(evt.IsRead)
  122. })
  123. suite.connMock.On("Write", mock.Anything).Once().Return(10, io.EOF)
  124. n, err := suite.conn.Write(make([]byte, 10))
  125. suite.True(errors.Is(err, io.EOF))
  126. suite.Equal(10, n)
  127. }
  128. func (suite *ConnTrafficTestSuite) TestWriteNothingOk() {
  129. suite.connMock.On("Write", mock.Anything).Once().Return(0, nil)
  130. n, err := suite.conn.Write(make([]byte, 10))
  131. suite.NoError(err)
  132. suite.Equal(0, n)
  133. }
  134. func (suite *ConnTrafficTestSuite) TestWriteNothingErr() {
  135. suite.connMock.On("Write", mock.Anything).Once().Return(0, io.EOF)
  136. n, err := suite.conn.Write(make([]byte, 10))
  137. suite.True(errors.Is(err, io.EOF))
  138. suite.Equal(0, n)
  139. }
  140. type ConnRewindTestSuite struct {
  141. suite.Suite
  142. connMock *ConnRewindBaseConn
  143. conn *connRewind
  144. }
  145. func (suite *ConnRewindTestSuite) SetupTest() {
  146. suite.connMock = &ConnRewindBaseConn{}
  147. suite.conn = newConnRewind(suite.connMock)
  148. }
  149. func (suite *ConnRewindTestSuite) TearDownTest() {
  150. suite.connMock.AssertExpectations(suite.T())
  151. }
  152. func (suite *ConnRewindTestSuite) TestRead() {
  153. suite.connMock.On("Read", mock.Anything)
  154. suite.connMock.readBuffer.Write([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})
  155. buf := make([]byte, 2)
  156. n, err := suite.conn.Read(buf)
  157. suite.NoError(err)
  158. suite.Equal(2, n)
  159. suite.Equal([]byte{1, 2}, buf)
  160. n, err = suite.conn.Read(buf)
  161. suite.NoError(err)
  162. suite.Equal(2, n)
  163. suite.Equal([]byte{3, 4}, buf)
  164. suite.conn.Rewind()
  165. data, err := io.ReadAll(suite.conn)
  166. suite.NoError(err)
  167. suite.Equal([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, data)
  168. }
  169. type ConnProxyProtocolTestSuite struct {
  170. suite.Suite
  171. sourceConnMock *testlib.EssentialsConnMock
  172. targetConnMock *testlib.EssentialsConnMock
  173. conn *connProxyProtocol
  174. }
  175. func (suite *ConnProxyProtocolTestSuite) SetupTest() {
  176. suite.sourceConnMock = &testlib.EssentialsConnMock{}
  177. suite.targetConnMock = &testlib.EssentialsConnMock{}
  178. localAddr := &net.TCPAddr{
  179. IP: net.ParseIP("127.0.0.1").To4(),
  180. }
  181. remoteAddr := &net.TCPAddr{
  182. IP: net.ParseIP("127.0.0.2").To4(),
  183. }
  184. suite.sourceConnMock.
  185. On("RemoteAddr").
  186. Return(localAddr)
  187. suite.targetConnMock.
  188. On("RemoteAddr").
  189. Maybe().
  190. Return(remoteAddr)
  191. suite.conn = newConnProxyProtocol(suite.sourceConnMock, suite.targetConnMock)
  192. }
  193. func (suite *ConnProxyProtocolTestSuite) TestRead() {
  194. value := []byte{1, 2, 3, 4, 5}
  195. toRead := make([]byte, len(value))
  196. suite.targetConnMock.
  197. On("Read", mock.AnythingOfType("[]uint8")).
  198. Once().
  199. Return(len(toRead), nil).
  200. Run(func(args mock.Arguments) {
  201. arr := args.Get(0).([]byte)
  202. copy(arr, value)
  203. })
  204. n, err := suite.conn.Read(toRead)
  205. suite.Equal(len(value), n)
  206. suite.NoError(err)
  207. suite.Equal(value, toRead)
  208. }
  209. func (suite *ConnProxyProtocolTestSuite) TestWrite() {
  210. value := []byte{1, 2, 3, 4, 5}
  211. buf := &bytes.Buffer{}
  212. bufReader := bufio.NewReader(buf)
  213. suite.targetConnMock.
  214. On("Write", mock.AnythingOfType("[]uint8")).
  215. Return(28, nil).
  216. Run(func(args mock.Arguments) {
  217. arr := args.Get(0).([]byte)
  218. buf.Write(arr)
  219. })
  220. _, err := suite.conn.Write(value)
  221. suite.NoError(err)
  222. header, err := proxyproto.Read(bufReader)
  223. suite.NoError(err)
  224. sourceAddr, destAddr, ok := header.TCPAddrs()
  225. suite.True(ok)
  226. suite.Equal(suite.sourceConnMock.RemoteAddr(), sourceAddr)
  227. suite.Equal(suite.targetConnMock.RemoteAddr(), destAddr)
  228. read, _ := io.ReadAll(bufReader)
  229. suite.Equal(value, read)
  230. _, err = suite.conn.Write(value)
  231. suite.NoError(err)
  232. read, _ = io.ReadAll(bufReader)
  233. suite.Equal(value, read)
  234. }
  235. func (suite *ConnProxyProtocolTestSuite) TearDownTest() {
  236. suite.sourceConnMock.AssertExpectations(suite.T())
  237. suite.targetConnMock.AssertExpectations(suite.T())
  238. }
  239. type IdleTrackerTestSuite struct {
  240. suite.Suite
  241. }
  242. func (suite *IdleTrackerTestSuite) TestNewNotIdle() {
  243. tracker := newIdleTracker(time.Second)
  244. suite.False(tracker.isIdle())
  245. }
  246. func (suite *IdleTrackerTestSuite) TestIdleAfterTimeout() {
  247. tracker := newIdleTracker(10 * time.Millisecond)
  248. time.Sleep(20 * time.Millisecond)
  249. suite.True(tracker.isIdle())
  250. }
  251. func (suite *IdleTrackerTestSuite) TestTouchResetsIdle() {
  252. tracker := newIdleTracker(50 * time.Millisecond)
  253. time.Sleep(30 * time.Millisecond)
  254. tracker.touch()
  255. suite.False(tracker.isIdle())
  256. }
  257. type ConnIdleTimeoutTestSuite struct {
  258. suite.Suite
  259. connMock *testlib.EssentialsConnMock
  260. tracker *idleTracker
  261. conn connIdleTimeout
  262. }
  263. func (suite *ConnIdleTimeoutTestSuite) SetupTest() {
  264. suite.connMock = &testlib.EssentialsConnMock{}
  265. suite.tracker = newIdleTracker(time.Second)
  266. suite.conn = connIdleTimeout{
  267. Conn: suite.connMock,
  268. tracker: suite.tracker,
  269. }
  270. }
  271. func (suite *ConnIdleTimeoutTestSuite) TearDownTest() {
  272. suite.connMock.AssertExpectations(suite.T())
  273. }
  274. func (suite *ConnIdleTimeoutTestSuite) TestReadOk() {
  275. suite.connMock.On("SetReadDeadline", mock.Anything).Return(nil)
  276. suite.connMock.On("Read", mock.Anything).Once().Return(5, nil)
  277. n, err := suite.conn.Read(make([]byte, 10))
  278. suite.NoError(err)
  279. suite.Equal(5, n)
  280. }
  281. func (suite *ConnIdleTimeoutTestSuite) TestReadNonTimeoutErr() {
  282. suite.connMock.On("SetReadDeadline", mock.Anything).Return(nil)
  283. suite.connMock.On("Read", mock.Anything).Once().Return(0, io.EOF)
  284. n, err := suite.conn.Read(make([]byte, 10))
  285. suite.True(errors.Is(err, io.EOF))
  286. suite.Equal(0, n)
  287. }
  288. func (suite *ConnIdleTimeoutTestSuite) TestReadTimeoutRetriesWhenNotIdle() {
  289. suite.connMock.On("SetReadDeadline", mock.Anything).Return(nil)
  290. suite.connMock.On("Read", mock.Anything).Once().Return(0, netTimeoutError{})
  291. suite.connMock.On("Read", mock.Anything).Once().Return(5, nil)
  292. n, err := suite.conn.Read(make([]byte, 10))
  293. suite.NoError(err)
  294. suite.Equal(5, n)
  295. }
  296. func (suite *ConnIdleTimeoutTestSuite) TestReadTimeoutClosesWhenIdle() {
  297. suite.tracker = newIdleTracker(time.Millisecond)
  298. suite.conn = connIdleTimeout{
  299. Conn: suite.connMock,
  300. tracker: suite.tracker,
  301. }
  302. time.Sleep(5 * time.Millisecond)
  303. suite.connMock.On("SetReadDeadline", mock.Anything).Return(nil)
  304. suite.connMock.On("Read", mock.Anything).Once().Return(0, netTimeoutError{})
  305. n, err := suite.conn.Read(make([]byte, 10))
  306. suite.Equal(0, n)
  307. netErr, ok := err.(net.Error) //nolint: errorlint
  308. suite.True(ok)
  309. suite.True(netErr.Timeout())
  310. }
  311. func (suite *ConnIdleTimeoutTestSuite) TestSharedTrackerPreventsFalseTimeout() {
  312. connMock2 := &testlib.EssentialsConnMock{}
  313. conn2 := connIdleTimeout{
  314. Conn: connMock2,
  315. tracker: suite.tracker,
  316. }
  317. connMock2.On("SetWriteDeadline", mock.Anything).Return(nil)
  318. connMock2.On("Write", mock.Anything).Once().Return(5, nil)
  319. _, _ = conn2.Write(make([]byte, 5))
  320. suite.connMock.On("SetReadDeadline", mock.Anything).Return(nil)
  321. suite.connMock.On("Read", mock.Anything).Once().Return(0, netTimeoutError{})
  322. suite.connMock.On("Read", mock.Anything).Once().Return(3, nil)
  323. n, err := suite.conn.Read(make([]byte, 10))
  324. suite.NoError(err)
  325. suite.Equal(3, n)
  326. connMock2.AssertExpectations(suite.T())
  327. }
  328. func (suite *ConnIdleTimeoutTestSuite) TestWriteOk() {
  329. suite.connMock.On("SetWriteDeadline", mock.Anything).Return(nil)
  330. suite.connMock.On("Write", mock.Anything).Once().Return(5, nil)
  331. n, err := suite.conn.Write(make([]byte, 5))
  332. suite.NoError(err)
  333. suite.Equal(5, n)
  334. }
  335. func (suite *ConnIdleTimeoutTestSuite) TestWriteErr() {
  336. suite.connMock.On("SetWriteDeadline", mock.Anything).Return(nil)
  337. suite.connMock.On("Write", mock.Anything).Once().Return(0, io.EOF)
  338. n, err := suite.conn.Write(make([]byte, 5))
  339. suite.True(errors.Is(err, io.EOF))
  340. suite.Equal(0, n)
  341. }
  342. func TestConnTraffic(t *testing.T) {
  343. t.Parallel()
  344. suite.Run(t, &ConnTrafficTestSuite{})
  345. }
  346. func TestConnRewind(t *testing.T) {
  347. t.Parallel()
  348. suite.Run(t, &ConnRewindTestSuite{})
  349. }
  350. func TestConnProxyProtocol(t *testing.T) {
  351. t.Parallel()
  352. suite.Run(t, &ConnProxyProtocolTestSuite{})
  353. }
  354. func TestIdleTracker(t *testing.T) {
  355. t.Parallel()
  356. suite.Run(t, &IdleTrackerTestSuite{})
  357. }
  358. func TestConnIdleTimeout(t *testing.T) {
  359. t.Parallel()
  360. suite.Run(t, &ConnIdleTimeoutTestSuite{})
  361. }