Browse Source

Add syncPair

tags/v2.1.3^2
9seconds 4 years ago
parent
commit
a5e59d9ef7

+ 5
- 1
mtglib/internal/relay/init.go View File

@@ -1,7 +1,11 @@
1 1
 package relay
2 2
 
3
+import "time"
4
+
3 5
 const (
4
-	bufferSize = 32 * 1024
6
+	copyBufferSize   = 32 * 1024
7
+	writerBufferSize = 2 * copyBufferSize
8
+	readTimeout      = 10 * time.Millisecond
5 9
 )
6 10
 
7 11
 type Logger interface {

+ 20
- 14
mtglib/internal/relay/pools.go View File

@@ -1,25 +1,31 @@
1 1
 package relay
2 2
 
3
-import "sync"
3
+import (
4
+	"bufio"
5
+	"io"
6
+	"net"
7
+	"sync"
8
+)
4 9
 
5
-type eastWest struct {
6
-	east []byte
7
-	west []byte
8
-}
9
-
10
-var eastWestPool = sync.Pool{
10
+var syncPairPool = sync.Pool{
11 11
 	New: func() interface{} {
12
-		return &eastWest{
13
-			east: make([]byte, bufferSize),
14
-			west: make([]byte, bufferSize),
12
+		return &syncPair{
13
+			writer:  bufio.NewWriterSize(nil, writerBufferSize),
14
+			copyBuf: make([]byte, copyBufferSize),
15 15
 		}
16 16
 	},
17 17
 }
18 18
 
19
-func acquireEastWest() *eastWest {
20
-	return eastWestPool.Get().(*eastWest)
19
+func acquireSyncPair(reader net.Conn, writer io.Writer) *syncPair {
20
+	sp := syncPairPool.Get().(*syncPair) // nolint: forcetypeassert
21
+	sp.writer.Reset(writer)
22
+	sp.reader = reader
23
+
24
+	return sp
21 25
 }
22 26
 
23
-func releaseEastWest(ew *eastWest) {
24
-	eastWestPool.Put(ew)
27
+func releaseSyncPair(sp *syncPair) {
28
+	sp.writer.Reset(nil)
29
+	sp.reader = nil
30
+	syncPairPool.Put(sp)
25 31
 }

+ 9
- 13
mtglib/internal/relay/relay.go View File

@@ -2,14 +2,11 @@ package relay
2 2
 
3 3
 import (
4 4
 	"context"
5
-	"io"
5
+	"net"
6 6
 	"sync"
7 7
 )
8 8
 
9
-func Relay(ctx context.Context, log Logger, telegramConn, clientConn io.ReadWriteCloser) {
10
-	defer telegramConn.Close()
11
-	defer clientConn.Close()
12
-
9
+func Relay(ctx context.Context, log Logger, telegramConn, clientConn net.Conn) {
13 10
 	ctx, cancel := context.WithCancel(ctx)
14 11
 	defer cancel()
15 12
 
@@ -19,26 +16,25 @@ func Relay(ctx context.Context, log Logger, telegramConn, clientConn io.ReadWrit
19 16
 		clientConn.Close()
20 17
 	}()
21 18
 
22
-	buffers := acquireEastWest()
23
-	defer releaseEastWest(buffers)
24
-
25 19
 	wg := &sync.WaitGroup{}
26 20
 	wg.Add(2) // nolint: gomnd
27 21
 
28
-	go pump(log, telegramConn, clientConn, wg, buffers.east, "east -> west")
22
+	go pump(log, telegramConn, clientConn, wg, "client -> telegram")
29 23
 
30
-	pump(log, clientConn, telegramConn, wg, buffers.west, "west -> east")
24
+	pump(log, clientConn, telegramConn, wg, "telegram -> client")
31 25
 
32 26
 	wg.Wait()
33 27
 }
34 28
 
35
-func pump(log Logger, src io.ReadCloser, dst io.WriteCloser, wg *sync.WaitGroup,
36
-	buf []byte, direction string) {
29
+func pump(log Logger, src, dst net.Conn, wg *sync.WaitGroup, direction string) {
37 30
 	defer wg.Done()
38 31
 	defer src.Close()
39 32
 	defer dst.Close()
40 33
 
41
-	if n, err := io.CopyBuffer(dst, src, buf); err != nil {
34
+	syncer := acquireSyncPair(src, dst)
35
+	defer releaseSyncPair(syncer)
36
+
37
+	if n, err := syncer.Sync(); err != nil {
42 38
 		log.Printf("cannot pump %s (written %d bytes): %w", direction, n, err)
43 39
 	}
44 40
 }

+ 54
- 0
mtglib/internal/relay/sync_pair.go View File

@@ -0,0 +1,54 @@
1
+package relay
2
+
3
+import (
4
+	"bufio"
5
+	"errors"
6
+	"fmt"
7
+	"io"
8
+	"net"
9
+	"os"
10
+	"time"
11
+)
12
+
13
+type syncPair struct {
14
+	writer  *bufio.Writer
15
+	copyBuf []byte
16
+
17
+	reader net.Conn
18
+}
19
+
20
+func (s *syncPair) Sync() (int64, error) {
21
+	return io.CopyBuffer(s, s, s.copyBuf) // nolint: wrapcheck
22
+}
23
+
24
+func (s *syncPair) Read(p []byte) (int, error) {
25
+	n, err := s.readBlocking(p, false)
26
+
27
+	if errors.Is(err, os.ErrDeadlineExceeded) {
28
+		if err := s.writer.Flush(); err != nil {
29
+			return 0, fmt.Errorf("cannot flush writer hand-side: %w", err)
30
+		}
31
+
32
+		return s.readBlocking(p, true)
33
+	}
34
+
35
+	return n, err
36
+}
37
+
38
+func (s *syncPair) Write(p []byte) (int, error) {
39
+	return s.writer.Write(p) // nolint: wrapcheck
40
+}
41
+
42
+func (s *syncPair) readBlocking(p []byte, blocking bool) (int, error) {
43
+	var deadline time.Time
44
+
45
+	if !blocking {
46
+		deadline = time.Now().Add(readTimeout)
47
+	}
48
+
49
+	if err := s.reader.SetReadDeadline(deadline); err != nil {
50
+		return 0, fmt.Errorf("cannot set read deadline: %w", err)
51
+	}
52
+
53
+	return s.reader.Read(p) // nolint: wrapcheck
54
+}

Loading…
Cancel
Save