Просмотр исходного кода

Merge remote-tracking branch 'origin/stable' into v2

tags/v2.2.2
9seconds 1 месяц назад
Родитель
Сommit
fc72de9e39

+ 32
- 46
mtglib/internal/doppel/conn.go Просмотреть файл

16
 }
16
 }
17
 
17
 
18
 type connPayload struct {
18
 type connPayload struct {
19
-	ctx           context.Context
20
-	ctxCancel     context.CancelCauseFunc
21
-	clock         Clock
22
-	wg            sync.WaitGroup
23
-	syncWriteLock sync.RWMutex
24
-	writeStream   bytes.Buffer
25
-	writeCond     *sync.Cond
19
+	ctx         context.Context
20
+	ctxCancel   context.CancelCauseFunc
21
+	clock       Clock
22
+	wg          sync.WaitGroup
23
+	writeStream bytes.Buffer
24
+	writtenCond sync.Cond
25
+	done        bool
26
 }
26
 }
27
 
27
 
28
 func (c Conn) Write(p []byte) (int, error) {
28
 func (c Conn) Write(p []byte) (int, error) {
29
-	c.p.syncWriteLock.RLock()
30
-	defer c.p.syncWriteLock.RUnlock()
31
-
32
-	c.p.writeCond.L.Lock()
33
-	c.p.writeStream.Write(p)
34
-	c.p.writeCond.L.Unlock()
35
-
36
-	return len(p), context.Cause(c.p.ctx)
37
-}
38
-
39
-func (c Conn) SyncWrite(p []byte) (int, error) {
40
-	c.p.syncWriteLock.Lock()
41
-	defer c.p.syncWriteLock.Unlock()
42
-
43
-	c.p.writeCond.L.Lock()
44
-	// wait until buffer is exhausted
45
-	for c.p.writeStream.Len() != 0 && context.Cause(c.p.ctx) == nil {
46
-		c.p.writeCond.Wait()
29
+	if len(p) == 0 {
30
+		return 0, context.Cause(c.p.ctx)
47
 	}
31
 	}
48
-	c.p.writeStream.Write(p)
49
-	c.p.writeCond.L.Unlock()
50
 
32
 
51
-	if err := context.Cause(c.p.ctx); err != nil {
52
-		return len(p), err
53
-	}
33
+	c.p.writtenCond.L.Lock()
34
+	c.p.writeStream.Write(p)
35
+	c.p.writtenCond.L.Unlock()
54
 
36
 
55
-	c.p.writeCond.L.Lock()
56
-	// wait until data will be sent
57
-	for c.p.writeStream.Len() != 0 && context.Cause(c.p.ctx) == nil {
58
-		c.p.writeCond.Wait()
59
-	}
60
-	c.p.writeCond.L.Unlock()
37
+	c.p.writtenCond.Signal()
61
 
38
 
62
 	return len(p), context.Cause(c.p.ctx)
39
 	return len(p), context.Cause(c.p.ctx)
63
 }
40
 }
69
 }
46
 }
70
 
47
 
71
 func (c Conn) start() {
48
 func (c Conn) start() {
72
-	defer c.p.writeCond.Broadcast()
73
-
74
 	buf := [tls.MaxRecordSize]byte{}
49
 	buf := [tls.MaxRecordSize]byte{}
75
 
50
 
76
 	for {
51
 	for {
80
 		case <-c.p.clock.tick:
55
 		case <-c.p.clock.tick:
81
 		}
56
 		}
82
 
57
 
83
-		c.p.writeCond.L.Lock()
84
-		n, err := c.p.writeStream.Read(buf[:c.p.clock.stats.Size()])
85
-		c.p.writeCond.L.Unlock()
58
+		size := c.p.clock.stats.Size()
59
+
60
+		c.p.writtenCond.L.Lock()
61
+		for c.p.writeStream.Len() == 0 && !c.p.done {
62
+			c.p.writtenCond.Wait()
63
+		}
64
+		n, _ := c.p.writeStream.Read(buf[tls.SizeHeader : tls.SizeHeader+size])
65
+		c.p.writtenCond.L.Unlock()
86
 
66
 
87
-		if n == 0 || err != nil {
67
+		if n == 0 {
88
 			continue
68
 			continue
89
 		}
69
 		}
90
 
70
 
91
-		if err := tls.WriteRecord(c.Conn, buf[:n]); err != nil {
71
+		if err := tls.WriteRecordInPlace(c.Conn, buf[:], n); err != nil {
92
 			c.p.ctxCancel(err)
72
 			c.p.ctxCancel(err)
93
 			return
73
 			return
94
 		}
74
 		}
95
-
96
-		c.p.writeCond.Signal()
97
 	}
75
 	}
98
 }
76
 }
99
 
77
 
100
 func (c Conn) Stop() {
78
 func (c Conn) Stop() {
101
 	c.p.ctxCancel(nil)
79
 	c.p.ctxCancel(nil)
80
+
81
+	c.p.writtenCond.L.Lock()
82
+	c.p.done = true
83
+	c.p.writtenCond.L.Unlock()
84
+	c.p.writtenCond.Broadcast()
85
+
102
 	c.p.wg.Wait()
86
 	c.p.wg.Wait()
103
 }
87
 }
104
 
88
 
109
 		p: &connPayload{
93
 		p: &connPayload{
110
 			ctx:       ctx,
94
 			ctx:       ctx,
111
 			ctxCancel: cancel,
95
 			ctxCancel: cancel,
112
-			writeCond: sync.NewCond(&sync.Mutex{}),
96
+			writtenCond: sync.Cond{
97
+				L: &sync.Mutex{},
98
+			},
113
 			clock: Clock{
99
 			clock: Clock{
114
 				stats: stats,
100
 				stats: stats,
115
 				tick:  make(chan struct{}),
101
 				tick:  make(chan struct{}),

+ 30
- 129
mtglib/internal/doppel/conn_test.go Просмотреть файл

141
 	suite.Error(err)
141
 	suite.Error(err)
142
 }
142
 }
143
 
143
 
144
-func (suite *ConnTestSuite) TestStopOnUnderlyingWriteError() {
145
-	suite.connMock.
146
-		On("Write", mock.AnythingOfType("[]uint8")).
147
-		Return(0, errors.New("connection reset")).
148
-		Maybe()
149
-
150
-	c := suite.makeConn()
151
-
152
-	_, _ = c.Write([]byte("data"))
153
-
154
-	suite.Eventually(func() bool {
155
-		_, err := c.Write([]byte{1})
156
-		return err != nil
157
-	}, 2*time.Second, time.Millisecond)
158
-}
159
-
160
-func (suite *ConnTestSuite) TestSyncWriteDataSent() {
161
-	suite.connMock.
162
-		On("Write", mock.AnythingOfType("[]uint8")).
163
-		Return(0, nil).
164
-		Maybe()
165
-
166
-	c := suite.makeConn()
167
-	defer c.Stop()
168
-
169
-	payload := []byte("sync hello")
170
-	n, err := c.SyncWrite(payload)
171
-	suite.NoError(err)
172
-	suite.Equal(len(payload), n)
173
-
174
-	// SyncWrite returns only after data is flushed to the wire.
175
-	assembled := &bytes.Buffer{}
176
-	reader := bytes.NewReader(suite.connMock.Written())
177
-
178
-	for {
179
-		header := make([]byte, tls.SizeHeader)
180
-		if _, err := io.ReadFull(reader, header); err != nil {
181
-			break
182
-		}
183
-
184
-		suite.Equal(byte(tls.TypeApplicationData), header[0])
185
-
186
-		length := binary.BigEndian.Uint16(header[tls.SizeRecordType+tls.SizeVersion:])
187
-		rec := make([]byte, length)
188
-		_, err := io.ReadFull(reader, rec)
189
-		suite.NoError(err)
190
-
191
-		assembled.Write(rec)
192
-	}
193
-
194
-	suite.Equal(payload, assembled.Bytes())
195
-}
196
-
197
-func (suite *ConnTestSuite) TestSyncWriteDrainsBufferFirst() {
198
-	suite.connMock.
199
-		On("Write", mock.AnythingOfType("[]uint8")).
200
-		Return(0, nil).
201
-		Maybe()
202
-
203
-	c := suite.makeConn()
204
-	defer c.Stop()
205
-
206
-	// Buffer some data via async Write.
207
-	_, err := c.Write([]byte("first"))
208
-	suite.NoError(err)
209
-
210
-	// SyncWrite must drain "first" before sending "second".
211
-	n, err := c.SyncWrite([]byte("second"))
212
-	suite.NoError(err)
213
-	suite.Equal(6, n)
214
-
215
-	// All data should be on the wire now.
216
-	assembled := &bytes.Buffer{}
217
-	reader := bytes.NewReader(suite.connMock.Written())
218
-
219
-	for {
220
-		header := make([]byte, tls.SizeHeader)
221
-		if _, err := io.ReadFull(reader, header); err != nil {
222
-			break
223
-		}
224
-
225
-		length := binary.BigEndian.Uint16(header[tls.SizeRecordType+tls.SizeVersion:])
226
-		rec := make([]byte, length)
227
-		_, err := io.ReadFull(reader, rec)
228
-		suite.NoError(err)
229
-
230
-		assembled.Write(rec)
231
-	}
232
-
233
-	suite.Equal([]byte("firstsecond"), assembled.Bytes())
234
-}
235
-
236
-func (suite *ConnTestSuite) TestSyncWriteBlocksAsyncWrite() {
144
+func (suite *ConnTestSuite) TestStopDoesNotDeadlockWhenStartIsWaiting() {
237
 	suite.connMock.
145
 	suite.connMock.
238
 		On("Write", mock.AnythingOfType("[]uint8")).
146
 		On("Write", mock.AnythingOfType("[]uint8")).
239
 		Return(0, nil).
147
 		Return(0, nil).
240
 		Maybe()
148
 		Maybe()
241
 
149
 
242
-	c := suite.makeConn()
243
-	defer c.Stop()
244
-
245
-	// Start SyncWrite — it holds exclusive lock.
246
-	syncDone := make(chan struct{})
247
-
248
-	go func() {
249
-		defer close(syncDone)
250
-		c.SyncWrite([]byte("exclusive")) //nolint: errcheck
251
-	}()
252
-
253
-	// Give SyncWrite time to acquire the lock.
254
-	time.Sleep(10 * time.Millisecond)
255
-
256
-	// Async Write should block until SyncWrite completes.
257
-	writeDone := make(chan struct{})
258
-
259
-	go func() {
260
-		defer close(writeDone)
261
-		c.Write([]byte("blocked")) //nolint: errcheck
262
-	}()
263
-
264
-	// SyncWrite should finish first.
265
-	<-syncDone
266
-
267
-	select {
268
-	case <-writeDone:
269
-		// Write completed after SyncWrite — correct.
270
-	case <-time.After(2 * time.Second):
271
-		suite.Fail("async Write did not unblock after SyncWrite completed")
150
+	for range 100 {
151
+		func() {
152
+			ctx, cancel := context.WithCancel(suite.ctx)
153
+			defer cancel()
154
+
155
+			c := NewConn(ctx, suite.connMock, &Stats{
156
+				k:      2.0,
157
+				lambda: 0.01,
158
+			})
159
+
160
+			done := make(chan struct{})
161
+			go func() {
162
+				defer close(done)
163
+				c.Stop()
164
+			}()
165
+
166
+			select {
167
+			case <-done:
168
+			case <-time.After(2 * time.Second):
169
+				suite.Fail("Stop() deadlocked: start() likely stuck in writtenCond.Wait()")
170
+			}
171
+		}()
272
 	}
172
 	}
273
 }
173
 }
274
 
174
 
275
-func (suite *ConnTestSuite) TestSyncWriteReturnsErrorAfterStop() {
175
+func (suite *ConnTestSuite) TestStopOnUnderlyingWriteError() {
276
 	suite.connMock.
176
 	suite.connMock.
277
 		On("Write", mock.AnythingOfType("[]uint8")).
177
 		On("Write", mock.AnythingOfType("[]uint8")).
278
-		Return(0, nil).
178
+		Return(0, errors.New("connection reset")).
279
 		Maybe()
179
 		Maybe()
280
 
180
 
281
 	c := suite.makeConn()
181
 	c := suite.makeConn()
282
-	c.Stop()
283
 
182
 
284
-	time.Sleep(10 * time.Millisecond)
183
+	_, _ = c.Write([]byte("data"))
285
 
184
 
286
-	_, err := c.SyncWrite([]byte("too late"))
287
-	suite.Error(err)
185
+	suite.Eventually(func() bool {
186
+		_, err := c.Write([]byte{1})
187
+		return err != nil
188
+	}, 2*time.Second, time.Millisecond)
288
 }
189
 }
289
 
190
 
290
 func TestConn(t *testing.T) {
191
 func TestConn(t *testing.T) {

+ 2
- 1
mtglib/internal/doppel/ganger.go Просмотреть файл

98
 			g.durations = append(g.durations, durations...)
98
 			g.durations = append(g.durations, durations...)
99
 
99
 
100
 			if len(g.durations) > DoppelGangerMaxDurations {
100
 			if len(g.durations) > DoppelGangerMaxDurations {
101
-				g.durations = g.durations[len(g.durations)-DoppelGangerMaxDurations:]
101
+				copy(g.durations, g.durations[len(g.durations)-DoppelGangerMaxDurations:])
102
+				g.durations = g.durations[:DoppelGangerMaxDurations]
102
 			}
103
 			}
103
 
104
 
104
 			if len(g.durations) < MinDurationsToCalculate {
105
 			if len(g.durations) < MinDurationsToCalculate {

+ 0
- 2
mtglib/internal/tls/conn.go Просмотреть файл

34
 
34
 
35
 type connPayload struct {
35
 type connPayload struct {
36
 	readBuf      bytes.Buffer
36
 	readBuf      bytes.Buffer
37
-	writeBuf     bytes.Buffer
38
 	connBuffered *bufio.Reader
37
 	connBuffered *bufio.Reader
39
 	read         bool
38
 	read         bool
40
 	write        bool
39
 	write        bool
80
 	}
79
 	}
81
 
80
 
82
 	newConn.p.readBuf.Grow(DefaultBufferSize)
81
 	newConn.p.readBuf.Grow(DefaultBufferSize)
83
-	newConn.p.writeBuf.Grow(DefaultBufferSize)
84
 
82
 
85
 	return newConn
83
 	return newConn
86
 }
84
 }

+ 14
- 10
mtglib/internal/tls/utils.go Просмотреть файл

29
 
29
 
30
 func WriteRecord(w io.Writer, payload []byte) error {
30
 func WriteRecord(w io.Writer, payload []byte) error {
31
 	buf := [MaxRecordSize]byte{}
31
 	buf := [MaxRecordSize]byte{}
32
-	buf[0] = TypeApplicationData
33
-
34
-	bufV := buf[SizeRecordType:]
35
-	copy(bufV[:SizeVersion], TLSVersion[:])
32
+	copy(buf[SizeHeader:], payload)
36
 
33
 
37
-	bufS := bufV[SizeVersion:]
38
-	binary.BigEndian.PutUint16(bufS[:SizeSize], uint16(len(payload)))
34
+	return WriteRecordInPlace(w, buf[:], len(payload))
35
+}
39
 
36
 
40
-	bufP := buf[SizeHeader:]
41
-	if n := copy(bufP, payload); n != len(payload) {
42
-		return fmt.Errorf("copied %d bytes of payload instead of %d", n, len(payload))
37
+func WriteRecordInPlace(w io.Writer, buf []byte, payloadLen int) error {
38
+	if payloadLen > MaxRecordPayloadSize {
39
+		return fmt.Errorf("payload %d exceeds max %d", payloadLen, MaxRecordPayloadSize)
43
 	}
40
 	}
44
 
41
 
45
-	_, err := w.Write(buf[:SizeHeader+len(payload)])
42
+	buf[0] = TypeApplicationData
43
+	copy(buf[SizeRecordType:SizeRecordType+SizeVersion], TLSVersion[:])
44
+	binary.BigEndian.PutUint16(
45
+		buf[SizeRecordType+SizeVersion:SizeRecordType+SizeVersion+SizeSize],
46
+		uint16(payloadLen),
47
+	)
48
+
49
+	_, err := w.Write(buf[:SizeHeader+payloadLen])
46
 
50
 
47
 	return err
51
 	return err
48
 }
52
 }

+ 78
- 0
mtglib/internal/tls/utils_test.go Просмотреть файл

119
 	suite.Error(err)
119
 	suite.Error(err)
120
 }
120
 }
121
 
121
 
122
+func (suite *UtilsTestSuite) TestWriteRecordInPlace() {
123
+	payload := []byte("hello in-place")
124
+
125
+	var buf [MaxRecordSize]byte
126
+	copy(buf[SizeHeader:], payload)
127
+
128
+	err := WriteRecordInPlace(suite.dst, buf[:], len(payload))
129
+	suite.NoError(err)
130
+
131
+	written := suite.dst.Bytes()
132
+	suite.Equal(byte(TypeApplicationData), written[0])
133
+	suite.Equal(TLSVersion[:], written[SizeRecordType:SizeRecordType+SizeVersion])
134
+
135
+	length := binary.BigEndian.Uint16(written[SizeRecordType+SizeVersion:])
136
+	suite.Equal(uint16(len(payload)), length)
137
+	suite.Equal(payload, written[SizeHeader:])
138
+}
139
+
140
+func (suite *UtilsTestSuite) TestWriteRecordInPlaceRoundTrip() {
141
+	payload := []byte("round trip in-place")
142
+
143
+	var buf [MaxRecordSize]byte
144
+	copy(buf[SizeHeader:], payload)
145
+
146
+	var wire bytes.Buffer
147
+
148
+	err := WriteRecordInPlace(&wire, buf[:], len(payload))
149
+	suite.NoError(err)
150
+
151
+	var recovered bytes.Buffer
152
+
153
+	recordType, length, err := ReadRecord(&wire, &recovered)
154
+	suite.NoError(err)
155
+	suite.Equal(byte(TypeApplicationData), recordType)
156
+	suite.Equal(int64(len(payload)), length)
157
+	suite.Equal(payload, recovered.Bytes())
158
+}
159
+
160
+func (suite *UtilsTestSuite) TestWriteRecordInPlacePayloadTooLarge() {
161
+	var buf [MaxRecordSize]byte
162
+
163
+	err := WriteRecordInPlace(suite.dst, buf[:], MaxRecordPayloadSize+1)
164
+	suite.Error(err)
165
+}
166
+
167
+func (suite *UtilsTestSuite) TestWriteRecordInPlacePropagatesError() {
168
+	m := &WriterMock{}
169
+	m.
170
+		On("Write", mock.AnythingOfType("[]uint8")).
171
+		Once().
172
+		Return(0, errors.New("disk full"))
173
+
174
+	var buf [MaxRecordSize]byte
175
+	copy(buf[SizeHeader:], []byte("data"))
176
+
177
+	err := WriteRecordInPlace(m, buf[:], 4)
178
+	suite.Error(err)
179
+
180
+	m.AssertExpectations(suite.T())
181
+}
182
+
183
+func (suite *UtilsTestSuite) TestWriteRecordInPlaceMatchesWriteRecord() {
184
+	payload := []byte("equivalence check")
185
+
186
+	var legacy bytes.Buffer
187
+	err := WriteRecord(&legacy, payload)
188
+	suite.NoError(err)
189
+
190
+	var buf [MaxRecordSize]byte
191
+	copy(buf[SizeHeader:], payload)
192
+
193
+	var inPlace bytes.Buffer
194
+	err = WriteRecordInPlace(&inPlace, buf[:], len(payload))
195
+	suite.NoError(err)
196
+
197
+	suite.Equal(legacy.Bytes(), inPlace.Bytes())
198
+}
199
+
122
 func TestUtils(t *testing.T) {
200
 func TestUtils(t *testing.T) {
123
 	t.Parallel()
201
 	t.Parallel()
124
 	suite.Run(t, &UtilsTestSuite{})
202
 	suite.Run(t, &UtilsTestSuite{})

+ 8
- 1
mtglib/proxy.go Просмотреть файл

259
 		ctx:      ctx,
259
 		ctx:      ctx,
260
 	}
260
 	}
261
 
261
 
262
+	telegramHost, _, err := net.SplitHostPort(foundAddr.Address)
263
+	if err != nil {
264
+		conn.Close() //nolint: errcheck
265
+
266
+		return fmt.Errorf("cannot parse telegram address %s: %w", foundAddr.Address, err)
267
+	}
268
+
262
 	p.eventStream.Send(ctx,
269
 	p.eventStream.Send(ctx,
263
 		NewEventConnectedToDC(ctx.streamID,
270
 		NewEventConnectedToDC(ctx.streamID,
264
-			conn.RemoteAddr().(*net.TCPAddr).IP, //nolint: forcetypeassert
271
+			net.ParseIP(telegramHost),
265
 			ctx.dc),
272
 			ctx.dc),
266
 	)
273
 	)
267
 
274
 

Загрузка…
Отмена
Сохранить