| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283 |
- package stream
-
- import (
- "bytes"
- "io"
- "net"
- "sync"
- "time"
-
- "github.com/9seconds/mtg/conntypes"
- "go.uber.org/zap"
- )
-
- type ReadWriteCloseRewinder interface {
- conntypes.StreamReadWriteCloser
- Rewind()
- }
-
- type wrapperRewind struct {
- parent conntypes.StreamReadWriteCloser
- activeReader io.Reader
- buf bytes.Buffer
- mutex sync.Mutex
- }
-
- func (w *wrapperRewind) Write(p []byte) (int, error) {
- return w.parent.Write(p)
- }
-
- func (w *wrapperRewind) WriteTimeout(p []byte, timeout time.Duration) (int, error) {
- return w.parent.WriteTimeout(p, timeout)
- }
-
- func (w *wrapperRewind) Read(p []byte) (int, error) {
- w.mutex.Lock()
- defer w.mutex.Unlock()
-
- return w.activeReader.Read(p)
- }
-
- func (w *wrapperRewind) ReadTimeout(p []byte, _ time.Duration) (int, error) {
- w.mutex.Lock()
- defer w.mutex.Unlock()
-
- return w.activeReader.Read(p)
- }
-
- func (w *wrapperRewind) Conn() net.Conn {
- return w.parent.Conn()
- }
-
- func (w *wrapperRewind) Logger() *zap.SugaredLogger {
- return w.parent.Logger().Named("rewinded")
- }
-
- func (w *wrapperRewind) LocalAddr() *net.TCPAddr {
- return w.parent.LocalAddr()
- }
-
- func (w *wrapperRewind) RemoteAddr() *net.TCPAddr {
- return w.parent.RemoteAddr()
- }
-
- func (w *wrapperRewind) Close() error {
- w.buf.Reset()
-
- return w.parent.Close()
- }
-
- func (w *wrapperRewind) Rewind() {
- w.mutex.Lock()
- w.activeReader = io.MultiReader(&w.buf, w.parent)
- w.mutex.Unlock()
- }
-
- func NewRewind(parent conntypes.StreamReadWriteCloser) ReadWriteCloseRewinder {
- rv := &wrapperRewind{
- parent: parent,
- }
- rv.activeReader = io.TeeReader(parent, &rv.buf)
-
- return rv
- }
|