ソースを参照

Correct closing of connections

tags/1.0^2
9seconds 6年前
コミット
56cf90b13d
3個のファイルの変更16行の追加9行の削除
  1. 1
    0
      hub/connection.go
  2. 7
    4
      proxy/direct.go
  3. 8
    5
      proxy/middle.go

+ 1
- 0
hub/connection.go ファイルの表示

@@ -52,6 +52,7 @@ func (c *connection) write(packet conntypes.Packet) error {
52 52
 
53 53
 func (c *connection) shutdown() {
54 54
 	c.shutdownOnce.Do(func() {
55
+		c.conn.Close()
55 56
 		close(c.done)
56 57
 		c.hub.channelBrokenSockets <- c.id
57 58
 	})

+ 7
- 4
proxy/direct.go ファイルの表示

@@ -27,14 +27,17 @@ func directConnection(request *protocol.TelegramRequest) error {
27 27
 	go directPipe(telegramConn, request.ClientConn, wg, request.Logger)
28 28
 	go directPipe(request.ClientConn, telegramConn, wg, request.Logger)
29 29
 
30
-	<-request.Ctx.Done()
31 30
 	wg.Wait()
32 31
 
33
-	return request.Ctx.Err()
32
+	return nil
34 33
 }
35 34
 
36
-func directPipe(dst io.Writer, src io.Reader, wg *sync.WaitGroup, logger *zap.SugaredLogger) {
37
-	defer wg.Done()
35
+func directPipe(dst io.WriteCloser, src io.ReadCloser, wg *sync.WaitGroup, logger *zap.SugaredLogger) {
36
+	defer func() {
37
+		dst.Close()
38
+		src.Close()
39
+		wg.Done()
40
+	}()
38 41
 
39 42
 	buf := make([]byte, directPipeBufferSize)
40 43
 	if _, err := io.CopyBuffer(dst, src, buf); err != nil {

+ 8
- 5
proxy/middle.go ファイルの表示

@@ -32,17 +32,20 @@ func middleConnection(request *protocol.TelegramRequest) error {
32 32
 	go middlePipe(telegramConn, clientConn, wg, request.Logger)
33 33
 	go middlePipe(clientConn, telegramConn, wg, request.Logger)
34 34
 
35
-	<-request.Ctx.Done()
36 35
 	wg.Wait()
37 36
 
38
-	return request.Ctx.Err()
37
+	return nil
39 38
 }
40 39
 
41
-func middlePipe(dst conntypes.PacketAckWriter,
42
-	src conntypes.PacketAckReader,
40
+func middlePipe(dst conntypes.PacketAckWriteCloser,
41
+	src conntypes.PacketAckReadCloser,
43 42
 	wg *sync.WaitGroup,
44 43
 	logger *zap.SugaredLogger) {
45
-	defer wg.Done()
44
+	defer func() {
45
+		dst.Close()
46
+		src.Close()
47
+		wg.Done()
48
+	}()
46 49
 
47 50
 	for {
48 51
 		acks := conntypes.ConnectionAcks{}

読み込み中…
キャンセル
保存