| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130 |
- package doppel
-
- import (
- "bytes"
- "context"
- "sync"
-
- "github.com/9seconds/mtg/v2/essentials"
- "github.com/9seconds/mtg/v2/mtglib/internal/tls"
- )
-
// Conn wraps an essentials.Conn with a buffered, clock-paced write path:
// bytes passed to Write/SyncWrite are queued in an internal buffer, and a
// background goroutine later emits them to the embedded connection with
// tls.WriteRecord. Methods that Conn does not define itself are promoted from
// the embedded essentials.Conn.
//
// Conn itself is a small value (an interface plus a pointer); all mutable
// state lives behind p, so copies of a Conn share the same buffer, context
// and goroutines.
type Conn struct {
	essentials.Conn

	// p is the shared state; always non-nil for a Conn built by NewConn.
	p *connPayload
}
-
// connPayload is the shared, mutable state behind a Conn.
type connPayload struct {
	// ctx is the connection-scoped context. The writer loop exits when it
	// is done, and Write/SyncWrite report context.Cause(ctx) to callers.
	ctx context.Context
	// ctxCancel cancels ctx with an optional cause: Stop passes nil, the
	// writer loop passes the error returned by tls.WriteRecord.
	ctxCancel context.CancelCauseFunc
	// clock paces the writer loop: one value received from clock.tick
	// lets it emit one chunk of at most clock.stats.Size() bytes.
	clock Clock
	// wg tracks the background goroutines so Stop can wait for them.
	wg sync.WaitGroup
	// syncWriteLock is held shared by Write and exclusively by SyncWrite,
	// so no Write can queue data while a SyncWrite is in progress.
	syncWriteLock sync.RWMutex
	// writeStream holds queued, not-yet-emitted bytes. Guarded by
	// writeCond.L.
	writeStream bytes.Buffer
	// writeCond's mutex guards writeStream; SyncWrite waits on it for the
	// writer loop to drain the buffer or for the context to end.
	writeCond *sync.Cond
}
-
- func (c Conn) Write(p []byte) (int, error) {
- c.p.syncWriteLock.RLock()
- defer c.p.syncWriteLock.RUnlock()
-
- c.p.writeCond.L.Lock()
- c.p.writeStream.Write(p)
- c.p.writeCond.L.Unlock()
-
- return len(p), context.Cause(c.p.ctx)
- }
-
- func (c Conn) SyncWrite(p []byte) (int, error) {
- c.p.syncWriteLock.Lock()
- defer c.p.syncWriteLock.Unlock()
-
- c.p.writeCond.L.Lock()
- // wait until buffer is exhausted
- for c.p.writeStream.Len() != 0 && context.Cause(c.p.ctx) == nil {
- c.p.writeCond.Wait()
- }
- c.p.writeStream.Write(p)
- c.p.writeCond.L.Unlock()
-
- if err := context.Cause(c.p.ctx); err != nil {
- return len(p), err
- }
-
- c.p.writeCond.L.Lock()
- // wait until data will be sent
- for c.p.writeStream.Len() != 0 && context.Cause(c.p.ctx) == nil {
- c.p.writeCond.Wait()
- }
- c.p.writeCond.L.Unlock()
-
- return len(p), context.Cause(c.p.ctx)
- }
-
- func (c Conn) Start() {
- c.p.wg.Go(func() {
- c.start()
- })
- }
-
- func (c Conn) start() {
- defer c.p.writeCond.Broadcast()
-
- buf := [tls.MaxRecordSize]byte{}
-
- for {
- select {
- case <-c.p.ctx.Done():
- return
- case <-c.p.clock.tick:
- }
-
- c.p.writeCond.L.Lock()
- n, err := c.p.writeStream.Read(buf[:c.p.clock.stats.Size()])
- c.p.writeCond.L.Unlock()
-
- if n == 0 || err != nil {
- continue
- }
-
- if err := tls.WriteRecord(c.Conn, buf[:n]); err != nil {
- c.p.ctxCancel(err)
- return
- }
-
- c.p.writeCond.Signal()
- }
- }
-
- func (c Conn) Stop() {
- c.p.ctxCancel(nil)
- c.p.wg.Wait()
- }
-
- func NewConn(ctx context.Context, conn essentials.Conn, stats *Stats) Conn {
- ctx, cancel := context.WithCancelCause(ctx)
- rv := Conn{
- Conn: conn,
- p: &connPayload{
- ctx: ctx,
- ctxCancel: cancel,
- writeCond: sync.NewCond(&sync.Mutex{}),
- clock: Clock{
- stats: stats,
- tick: make(chan struct{}),
- },
- },
- }
-
- rv.p.writeStream.Grow(tls.DefaultBufferSize)
-
- rv.p.wg.Go(func() {
- rv.p.clock.Start(ctx)
- })
- rv.p.wg.Go(func() {
- rv.start()
- })
-
- return rv
- }
|