瀏覽代碼

Simplify relay

tags/v2.0.0-rc1
9seconds 5 年之前
父節點
當前提交
4c3f42e264
共有 3 個檔案被更改,包括 23 行新增12 行删除
  1. 1
    1
      mtglib/internal/faketls/conn.go
  2. 10
    6
      mtglib/internal/relay/conn.go
  3. 12
    5
      mtglib/internal/relay/relay.go

+ 1
- 1
mtglib/internal/faketls/conn.go 查看文件

@@ -29,11 +29,11 @@ func (c *Conn) Read(p []byte) (int, error) {
29 29
 		}
30 30
 
31 31
 		switch rec.Type { // nolint: exhaustive
32
-		case record.TypeChangeCipherSpec:
33 32
 		case record.TypeApplicationData:
34 33
 			rec.Payload.WriteTo(&c.readBuffer) // nolint: errcheck
35 34
 
36 35
 			return c.readBuffer.Read(p)
36
+		case record.TypeChangeCipherSpec:
37 37
 		default:
38 38
 			return 0, fmt.Errorf("unsupported record type %v", rec.Type)
39 39
 		}

+ 10
- 6
mtglib/internal/relay/conn.go 查看文件

@@ -1,19 +1,23 @@
1 1
 package relay
2 2
 
3
-import "io"
3
+import (
4
+	"context"
5
+	"io"
6
+)
4 7
 
5 8
 type conn struct {
6 9
 	io.ReadWriteCloser
7 10
 
8
-	relay *Relay
11
+	ctx         context.Context
12
+	tickChannel chan struct{}
9 13
 }
10 14
 
11 15
 func (c conn) Read(p []byte) (int, error) {
12 16
 	n, err := c.ReadWriteCloser.Read(p)
13 17
 
14 18
 	select {
15
-	case <-c.relay.ctx.Done():
16
-	case c.relay.tickChannel <- struct{}{}:
19
+	case <-c.ctx.Done():
20
+	case c.tickChannel <- struct{}{}:
17 21
 	}
18 22
 
19 23
 	return n, err // nolint: wrapcheck
@@ -23,8 +27,8 @@ func (c conn) Write(p []byte) (int, error) {
23 27
 	n, err := c.ReadWriteCloser.Write(p)
24 28
 
25 29
 	select {
26
-	case <-c.relay.ctx.Done():
27
-	case c.relay.tickChannel <- struct{}{}:
30
+	case <-c.ctx.Done():
31
+	case c.tickChannel <- struct{}{}:
28 32
 	}
29 33
 
30 34
 	return n, err // nolint: wrapcheck

+ 12
- 5
mtglib/internal/relay/relay.go 查看文件

@@ -21,11 +21,13 @@ type Relay struct {
21 21
 func (r *Relay) Process(eastConn, westConn io.ReadWriteCloser) error {
22 22
 	eastConn = conn{
23 23
 		ReadWriteCloser: eastConn,
24
-		relay:           r,
24
+		ctx:             r.ctx,
25
+		tickChannel:     r.tickChannel,
25 26
 	}
26 27
 	westConn = conn{
27 28
 		ReadWriteCloser: westConn,
28
-		relay:           r,
29
+		ctx:             r.ctx,
30
+		tickChannel:     r.tickChannel,
29 31
 	}
30 32
 
31 33
 	defer func() {
@@ -37,7 +39,7 @@ func (r *Relay) Process(eastConn, westConn io.ReadWriteCloser) error {
37 39
 	wg := &sync.WaitGroup{}
38 40
 	wg.Add(3) // nolint: gomnd
39 41
 
40
-	go r.runObserver(r.ctx, wg)
42
+	go r.runObserver(wg)
41 43
 
42 44
 	go r.transmit(eastConn, westConn, r.westBuffer, "west", wg)
43 45
 
@@ -66,13 +68,18 @@ func (r *Relay) transmit(src io.ReadCloser, dst io.WriteCloser,
66 68
 
67 69
 		select {
68 70
 		case <-r.ctx.Done():
71
+			err = r.ctx.Err()
72
+		default:
73
+		}
74
+
75
+		select {
69 76
 		case r.errorChannel <- err:
70 77
 		default:
71 78
 		}
72 79
 	}
73 80
 }
74 81
 
75
-func (r *Relay) runObserver(ctx context.Context, wg *sync.WaitGroup) {
82
+func (r *Relay) runObserver(wg *sync.WaitGroup) {
76 83
 	ticker := time.NewTicker(time.Second)
77 84
 
78 85
 	defer func() {
@@ -90,7 +97,7 @@ func (r *Relay) runObserver(ctx context.Context, wg *sync.WaitGroup) {
90 97
 
91 98
 	for {
92 99
 		select {
93
-		case <-ctx.Done():
100
+		case <-r.ctx.Done():
94 101
 			return
95 102
 		case <-r.tickChannel:
96 103
 			lastTickAt = time.Now()

Loading…
取消
儲存