Преглед на файлове

Add SyncWrite method to doppel.Conn

tags/v2.2.0^2^2
9seconds преди 1 месец
родител
ревизия
33c0fa9bf7
променени са 2 файла, в които са добавени 175 реда и са изтрити 10 реда
  1. 45
    10
      mtglib/internal/doppel/conn.go
  2. 130
    0
      mtglib/internal/doppel/conn_test.go

+ 45
- 10
mtglib/internal/doppel/conn.go Целия файл

@@ -16,18 +16,48 @@ type Conn struct {
16 16
 }
17 17
 
18 18
 type connPayload struct {
19
-	ctx         context.Context
20
-	ctxCancel   context.CancelCauseFunc
21
-	clock       Clock
22
-	wg          sync.WaitGroup
23
-	writeLock   sync.Mutex
24
-	writeStream bytes.Buffer
19
+	ctx           context.Context
20
+	ctxCancel     context.CancelCauseFunc
21
+	clock         Clock
22
+	wg            sync.WaitGroup
23
+	syncWriteLock sync.RWMutex
24
+	writeStream   bytes.Buffer
25
+	writeCond     *sync.Cond
25 26
 }
26 27
 
27 28
 func (c Conn) Write(p []byte) (int, error) {
28
-	c.p.writeLock.Lock()
29
+	c.p.syncWriteLock.RLock()
30
+	defer c.p.syncWriteLock.RUnlock()
31
+
32
+	c.p.writeCond.L.Lock()
33
+	c.p.writeStream.Write(p)
34
+	c.p.writeCond.L.Unlock()
35
+
36
+	return len(p), context.Cause(c.p.ctx)
37
+}
38
+
39
+func (c Conn) SyncWrite(p []byte) (int, error) {
40
+	c.p.syncWriteLock.Lock()
41
+	defer c.p.syncWriteLock.Unlock()
42
+
43
+	c.p.writeCond.L.Lock()
44
+	// wait until buffer is exhausted
45
+	for c.p.writeStream.Len() != 0 && context.Cause(c.p.ctx) == nil {
46
+		c.p.writeCond.Wait()
47
+	}
29 48
 	c.p.writeStream.Write(p)
30
-	c.p.writeLock.Unlock()
49
+	c.p.writeCond.L.Unlock()
50
+
51
+	if err := context.Cause(c.p.ctx); err != nil {
52
+		return len(p), err
53
+	}
54
+
55
+	c.p.writeCond.L.Lock()
56
+	// wait until data will be sent
57
+	for c.p.writeStream.Len() != 0 && context.Cause(c.p.ctx) == nil {
58
+		c.p.writeCond.Wait()
59
+	}
60
+	c.p.writeCond.L.Unlock()
31 61
 
32 62
 	return len(p), context.Cause(c.p.ctx)
33 63
 }
@@ -39,6 +69,8 @@ func (c Conn) Start() {
39 69
 }
40 70
 
41 71
 func (c Conn) start() {
72
+	defer c.p.writeCond.Broadcast()
73
+
42 74
 	buf := [tls.MaxRecordSize]byte{}
43 75
 
44 76
 	for {
@@ -48,9 +80,9 @@ func (c Conn) start() {
48 80
 		case <-c.p.clock.tick:
49 81
 		}
50 82
 
51
-		c.p.writeLock.Lock()
83
+		c.p.writeCond.L.Lock()
52 84
 		n, err := c.p.writeStream.Read(buf[:c.p.clock.stats.Size()])
53
-		c.p.writeLock.Unlock()
85
+		c.p.writeCond.L.Unlock()
54 86
 
55 87
 		if n == 0 || err != nil {
56 88
 			continue
@@ -60,6 +92,8 @@ func (c Conn) start() {
60 92
 			c.p.ctxCancel(err)
61 93
 			return
62 94
 		}
95
+
96
+		c.p.writeCond.Signal()
63 97
 	}
64 98
 }
65 99
 
@@ -75,6 +109,7 @@ func NewConn(ctx context.Context, conn essentials.Conn, stats *Stats) Conn {
75 109
 		p: &connPayload{
76 110
 			ctx:       ctx,
77 111
 			ctxCancel: cancel,
112
+			writeCond: sync.NewCond(&sync.Mutex{}),
78 113
 			clock: Clock{
79 114
 				stats: stats,
80 115
 				tick:  make(chan struct{}),

+ 130
- 0
mtglib/internal/doppel/conn_test.go Целия файл

@@ -157,6 +157,136 @@ func (suite *ConnTestSuite) TestStopOnUnderlyingWriteError() {
157 157
 	}, 2*time.Second, time.Millisecond)
158 158
 }
159 159
 
160
+func (suite *ConnTestSuite) TestSyncWriteDataSent() {
161
+	suite.connMock.
162
+		On("Write", mock.AnythingOfType("[]uint8")).
163
+		Return(0, nil).
164
+		Maybe()
165
+
166
+	c := suite.makeConn()
167
+	defer c.Stop()
168
+
169
+	payload := []byte("sync hello")
170
+	n, err := c.SyncWrite(payload)
171
+	suite.NoError(err)
172
+	suite.Equal(len(payload), n)
173
+
174
+	// SyncWrite returns only after data is flushed to the wire.
175
+	assembled := &bytes.Buffer{}
176
+	reader := bytes.NewReader(suite.connMock.Written())
177
+
178
+	for {
179
+		header := make([]byte, tls.SizeHeader)
180
+		if _, err := io.ReadFull(reader, header); err != nil {
181
+			break
182
+		}
183
+
184
+		suite.Equal(byte(tls.TypeApplicationData), header[0])
185
+
186
+		length := binary.BigEndian.Uint16(header[tls.SizeRecordType+tls.SizeVersion:])
187
+		rec := make([]byte, length)
188
+		_, err := io.ReadFull(reader, rec)
189
+		suite.NoError(err)
190
+
191
+		assembled.Write(rec)
192
+	}
193
+
194
+	suite.Equal(payload, assembled.Bytes())
195
+}
196
+
197
+func (suite *ConnTestSuite) TestSyncWriteDrainsBufferFirst() {
198
+	suite.connMock.
199
+		On("Write", mock.AnythingOfType("[]uint8")).
200
+		Return(0, nil).
201
+		Maybe()
202
+
203
+	c := suite.makeConn()
204
+	defer c.Stop()
205
+
206
+	// Buffer some data via async Write.
207
+	_, err := c.Write([]byte("first"))
208
+	suite.NoError(err)
209
+
210
+	// SyncWrite must drain "first" before sending "second".
211
+	n, err := c.SyncWrite([]byte("second"))
212
+	suite.NoError(err)
213
+	suite.Equal(6, n)
214
+
215
+	// All data should be on the wire now.
216
+	assembled := &bytes.Buffer{}
217
+	reader := bytes.NewReader(suite.connMock.Written())
218
+
219
+	for {
220
+		header := make([]byte, tls.SizeHeader)
221
+		if _, err := io.ReadFull(reader, header); err != nil {
222
+			break
223
+		}
224
+
225
+		length := binary.BigEndian.Uint16(header[tls.SizeRecordType+tls.SizeVersion:])
226
+		rec := make([]byte, length)
227
+		_, err := io.ReadFull(reader, rec)
228
+		suite.NoError(err)
229
+
230
+		assembled.Write(rec)
231
+	}
232
+
233
+	suite.Equal([]byte("firstsecond"), assembled.Bytes())
234
+}
235
+
236
+func (suite *ConnTestSuite) TestSyncWriteBlocksAsyncWrite() {
237
+	suite.connMock.
238
+		On("Write", mock.AnythingOfType("[]uint8")).
239
+		Return(0, nil).
240
+		Maybe()
241
+
242
+	c := suite.makeConn()
243
+	defer c.Stop()
244
+
245
+	// Start SyncWrite — it holds exclusive lock.
246
+	syncDone := make(chan struct{})
247
+
248
+	go func() {
249
+		defer close(syncDone)
250
+		c.SyncWrite([]byte("exclusive"))
251
+	}()
252
+
253
+	// Give SyncWrite time to acquire the lock.
254
+	time.Sleep(10 * time.Millisecond)
255
+
256
+	// Async Write should block until SyncWrite completes.
257
+	writeDone := make(chan struct{})
258
+
259
+	go func() {
260
+		defer close(writeDone)
261
+		c.Write([]byte("blocked"))
262
+	}()
263
+
264
+	// SyncWrite should finish first.
265
+	<-syncDone
266
+
267
+	select {
268
+	case <-writeDone:
269
+		// Write completed after SyncWrite — correct.
270
+	case <-time.After(2 * time.Second):
271
+		suite.Fail("async Write did not unblock after SyncWrite completed")
272
+	}
273
+}
274
+
275
+func (suite *ConnTestSuite) TestSyncWriteReturnsErrorAfterStop() {
276
+	suite.connMock.
277
+		On("Write", mock.AnythingOfType("[]uint8")).
278
+		Return(0, nil).
279
+		Maybe()
280
+
281
+	c := suite.makeConn()
282
+	c.Stop()
283
+
284
+	time.Sleep(10 * time.Millisecond)
285
+
286
+	_, err := c.SyncWrite([]byte("too late"))
287
+	suite.Error(err)
288
+}
289
+
160 290
 func TestConn(t *testing.T) {
161 291
 	t.Parallel()
162 292
 	suite.Run(t, &ConnTestSuite{})

Loading…
Отказ
Запис