Просмотр исходного кода

Merge pull request #177 from 9seconds/faketls-hotfix

Correct rewinding for faketls
tags/v1.0.9
Sergey Arkhipov 5 лет назад
Родитель
Сommit
ed495b3800
Аккаунт пользователя с таким Email не найден
1 измененных файлов: 12 добавлений и 34 удалений
  1. 12
    34
      wrappers/stream/rewind.go

+ 12
- 34
wrappers/stream/rewind.go Просмотреть файл

@@ -2,7 +2,6 @@ package stream
2 2
 
3 3
 import (
4 4
 	"bytes"
5
-	"errors"
6 5
 	"io"
7 6
 	"net"
8 7
 	"sync"
@@ -18,10 +17,10 @@ type ReadWriteCloseRewinder interface {
18 17
 }
19 18
 
20 19
 type wrapperRewind struct {
21
-	parent   conntypes.StreamReadWriteCloser
22
-	buf      bytes.Buffer
23
-	mutex    sync.Mutex
24
-	rewinded bool
20
+	parent       conntypes.StreamReadWriteCloser
21
+	activeReader io.Reader
22
+	buf          bytes.Buffer
23
+	mutex        sync.Mutex
25 24
 }
26 25
 
27 26
 func (w *wrapperRewind) Write(p []byte) (int, error) {
@@ -36,38 +35,14 @@ func (w *wrapperRewind) Read(p []byte) (int, error) {
36 35
 	w.mutex.Lock()
37 36
 	defer w.mutex.Unlock()
38 37
 
39
-	if w.rewinded {
40
-		if n, err := w.buf.Read(p); errors.Is(err, io.EOF) {
41
-			return n, err // nolint: wrapcheck
42
-		}
43
-	}
44
-
45
-	n, err := w.parent.Read(p)
46
-
47
-	if !w.rewinded {
48
-		w.buf.Write(p[:n])
49
-	}
50
-
51
-	return n, err // nolint: wrapcheck
38
+	return w.activeReader.Read(p)
52 39
 }
53 40
 
54
-func (w *wrapperRewind) ReadTimeout(p []byte, timeout time.Duration) (int, error) {
41
+func (w *wrapperRewind) ReadTimeout(p []byte, _ time.Duration) (int, error) {
55 42
 	w.mutex.Lock()
56 43
 	defer w.mutex.Unlock()
57 44
 
58
-	if w.rewinded {
59
-		if n, err := w.buf.Read(p); errors.Is(err, io.EOF) {
60
-			return n, err // nolint: wrapcheck
61
-		}
62
-	}
63
-
64
-	n, err := w.parent.ReadTimeout(p, timeout)
65
-
66
-	if !w.rewinded {
67
-		w.buf.Write(p[:n])
68
-	}
69
-
70
-	return n, err // nolint: wrapcheck
45
+	return w.activeReader.Read(p)
71 46
 }
72 47
 
73 48
 func (w *wrapperRewind) Conn() net.Conn {
@@ -94,12 +69,15 @@ func (w *wrapperRewind) Close() error {
94 69
 
95 70
 func (w *wrapperRewind) Rewind() {
96 71
 	w.mutex.Lock()
97
-	w.rewinded = true
72
+	w.activeReader = io.MultiReader(&w.buf, w.parent)
98 73
 	w.mutex.Unlock()
99 74
 }
100 75
 
101 76
 func NewRewind(parent conntypes.StreamReadWriteCloser) ReadWriteCloseRewinder {
102
-	return &wrapperRewind{
77
+	rv := &wrapperRewind{
103 78
 		parent: parent,
104 79
 	}
80
+	rv.activeReader = io.TeeReader(parent, &rv.buf)
81
+
82
+	return rv
105 83
 }

Загрузка…
Отмена
Сохранить