| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129 |
- package relay
-
- import (
- "context"
- "io"
- "sync"
- "time"
- )
-
- type Relay struct {
- ctx context.Context
- ctxCancel context.CancelFunc
- logger Logger
- processMutex sync.Mutex
- eastBuffer []byte
- westBuffer []byte
- tickChannel chan struct{}
- errorChannel chan error
- tickTimeout time.Duration
- }
-
- func (r *Relay) Reset() {
- r.processMutex.Lock()
- defer r.processMutex.Unlock()
-
- if r.ctxCancel != nil {
- r.ctxCancel()
- }
-
- r.ctx = nil
- r.ctxCancel = nil
- r.logger = nil
- }
-
- func (r *Relay) Process(eastConn, westConn io.ReadWriteCloser) error {
- r.processMutex.Lock()
- defer r.processMutex.Unlock()
-
- eastConn = conn{
- ReadWriteCloser: eastConn,
- ctx: r.ctx,
- tickChannel: r.tickChannel,
- }
- westConn = conn{
- ReadWriteCloser: westConn,
- ctx: r.ctx,
- tickChannel: r.tickChannel,
- }
-
- wg := &sync.WaitGroup{}
- wg.Add(3) // nolint: gomnd
-
- go r.runObserver(eastConn, westConn, wg)
-
- go r.transmit(eastConn, westConn, r.westBuffer, "west", wg)
-
- r.transmit(westConn, eastConn, r.eastBuffer, "east", wg)
-
- wg.Wait()
-
- select {
- case err := <-r.errorChannel:
- return err
- default:
- return nil
- }
- }
-
- func (r *Relay) transmit(src io.ReadCloser, dst io.WriteCloser,
- buffer []byte, direction string, wg *sync.WaitGroup) {
- defer wg.Done()
-
- defer func() {
- r.ctxCancel()
- src.Close()
- dst.Close()
- }()
-
- if _, err := io.CopyBuffer(dst, src, buffer); err != nil {
- r.logger.Printf("error '%v' happened on direction %s", err, direction)
-
- select {
- case <-r.ctx.Done():
- err = r.ctx.Err()
- default:
- }
-
- select {
- case r.errorChannel <- err:
- default:
- }
- }
- }
-
- func (r *Relay) runObserver(one, another io.Closer, wg *sync.WaitGroup) {
- defer wg.Done()
-
- ticker := time.NewTicker(time.Second)
-
- defer func() {
- one.Close()
- another.Close()
-
- ticker.Stop()
-
- select {
- case <-ticker.C:
- default:
- }
- }()
-
- lastTickAt := time.Now()
-
- for {
- select {
- case <-r.ctx.Done():
- return
- case <-r.tickChannel:
- lastTickAt = time.Now()
- case <-ticker.C:
- if time.Since(lastTickAt) > r.tickTimeout {
- r.logger.Printf("exit due to a timeout")
- r.ctxCancel()
-
- return
- }
- }
- }
- }
|