Просмотр исходного кода

Add doppel and tls packages

tags/v2.2.0^2^2
9seconds 1 месяц назад
Родитель
Сommit
1182b9ef6f

+ 35
- 0
mtglib/internal/doppel/clock.go Просмотреть файл

1
+package doppel
2
+
3
+import (
4
+	"context"
5
+	"time"
6
+)
7
+
8
+type Clock struct {
9
+	stats *Stats
10
+	tick  chan struct{}
11
+}
12
+
13
+func (c Clock) Start(ctx context.Context) {
14
+	tickTock := time.NewTimer(c.stats.Delay())
15
+	defer func() {
16
+		tickTock.Stop()
17
+		select {
18
+		case <-tickTock.C:
19
+		default:
20
+		}
21
+	}()
22
+
23
+	for {
24
+		select {
25
+		case <-ctx.Done():
26
+			return
27
+		case <-tickTock.C:
28
+			select {
29
+			case <-ctx.Done():
30
+			case c.tick <- struct{}{}:
31
+			}
32
+			tickTock.Reset(c.stats.Delay())
33
+		}
34
+	}
35
+}

+ 80
- 0
mtglib/internal/doppel/clock_test.go Просмотреть файл

1
+package doppel
2
+
3
+import (
4
+	"context"
5
+	"sync"
6
+	"testing"
7
+	"time"
8
+
9
+	"github.com/stretchr/testify/suite"
10
+)
11
+
12
+type ClockTestSuite struct {
13
+	suite.Suite
14
+
15
+	clock     Clock
16
+	wg        sync.WaitGroup
17
+	ctx       context.Context
18
+	ctxCancel context.CancelFunc
19
+}
20
+
21
+func (suite *ClockTestSuite) SetupTest() {
22
+	ctx, cancel := context.WithCancel(context.Background())
23
+
24
+	suite.ctx = ctx
25
+	suite.ctxCancel = cancel
26
+	suite.clock = Clock{
27
+		stats: &Stats{
28
+			k:      StatsDefaultK,
29
+			lambda: StatsDefaultLambda,
30
+		},
31
+		tick: make(chan struct{}),
32
+	}
33
+
34
+	suite.wg.Go(func() {
35
+		suite.clock.Start(suite.ctx)
36
+	})
37
+}
38
+
39
+func (suite *ClockTestSuite) TearDownTest() {
40
+	suite.ctxCancel()
41
+	suite.wg.Wait()
42
+}
43
+
44
+func (suite *ClockTestSuite) TestTicks() {
45
+	received := 0
46
+
47
+	for range 3 {
48
+		select {
49
+		case <-suite.clock.tick:
50
+			received++
51
+		case <-time.After(2 * time.Second):
52
+			suite.Fail("timed out waiting for tick")
53
+		}
54
+	}
55
+
56
+	suite.Equal(3, received)
57
+}
58
+
59
+func (suite *ClockTestSuite) TestStopsOnCancel() {
60
+	select {
61
+	case <-suite.clock.tick:
62
+	case <-time.After(2 * time.Second):
63
+		suite.Fail("timed out waiting for first tick")
64
+	}
65
+
66
+	suite.ctxCancel()
67
+
68
+	time.Sleep(50 * time.Millisecond)
69
+
70
+	select {
71
+	case <-suite.clock.tick:
72
+		suite.Fail("received tick after cancel")
73
+	default:
74
+	}
75
+}
76
+
77
+func TestClock(t *testing.T) {
78
+	t.Parallel()
79
+	suite.Run(t, &ClockTestSuite{})
80
+}

+ 95
- 0
mtglib/internal/doppel/conn.go Просмотреть файл

1
+package doppel
2
+
3
+import (
4
+	"bytes"
5
+	"context"
6
+	"sync"
7
+
8
+	"github.com/9seconds/mtg/v2/essentials"
9
+	"github.com/9seconds/mtg/v2/mtglib/internal/tls"
10
+)
11
+
12
+type Conn struct {
13
+	essentials.Conn
14
+
15
+	p *connPayload
16
+}
17
+
18
+type connPayload struct {
19
+	ctx         context.Context
20
+	ctxCancel   context.CancelCauseFunc
21
+	clock       Clock
22
+	wg          sync.WaitGroup
23
+	writeLock   sync.Mutex
24
+	writeStream bytes.Buffer
25
+}
26
+
27
+func (c Conn) Write(p []byte) (int, error) {
28
+	c.p.writeLock.Lock()
29
+	c.p.writeStream.Write(p)
30
+	c.p.writeLock.Unlock()
31
+
32
+	return len(p), context.Cause(c.p.ctx)
33
+}
34
+
35
+func (c Conn) Start() {
36
+	c.p.wg.Go(func() {
37
+		c.start()
38
+	})
39
+}
40
+
41
+func (c Conn) start() {
42
+	buf := [tls.MaxRecordSize]byte{}
43
+
44
+	for {
45
+		select {
46
+		case <-c.p.ctx.Done():
47
+			return
48
+		case <-c.p.clock.tick:
49
+		}
50
+
51
+		c.p.writeLock.Lock()
52
+		n, err := c.p.writeStream.Read(buf[:c.p.clock.stats.Size()])
53
+		c.p.writeLock.Unlock()
54
+
55
+		if n == 0 || err != nil {
56
+			continue
57
+		}
58
+
59
+		if err := tls.WriteRecord(c.Conn, buf[:n]); err != nil {
60
+			c.p.ctxCancel(err)
61
+			return
62
+		}
63
+	}
64
+}
65
+
66
+func (c Conn) Stop() {
67
+	c.p.ctxCancel(nil)
68
+	c.p.wg.Wait()
69
+}
70
+
71
+func NewConn(ctx context.Context, conn essentials.Conn, stats *Stats) Conn {
72
+	ctx, cancel := context.WithCancelCause(ctx)
73
+	rv := Conn{
74
+		Conn: conn,
75
+		p: &connPayload{
76
+			ctx:       ctx,
77
+			ctxCancel: cancel,
78
+			clock: Clock{
79
+				stats: stats,
80
+				tick:  make(chan struct{}),
81
+			},
82
+		},
83
+	}
84
+
85
+	rv.p.writeStream.Grow(tls.DefaultBufferSize)
86
+
87
+	rv.p.wg.Go(func() {
88
+		rv.p.clock.Start(ctx)
89
+	})
90
+	rv.p.wg.Go(func() {
91
+		rv.start()
92
+	})
93
+
94
+	return rv
95
+}

+ 163
- 0
mtglib/internal/doppel/conn_test.go Просмотреть файл

1
+package doppel
2
+
3
+import (
4
+	"bytes"
5
+	"context"
6
+	"encoding/binary"
7
+	"errors"
8
+	"io"
9
+	"sync"
10
+	"testing"
11
+	"time"
12
+
13
+	"github.com/9seconds/mtg/v2/internal/testlib"
14
+	"github.com/9seconds/mtg/v2/mtglib/internal/tls"
15
+	"github.com/stretchr/testify/mock"
16
+	"github.com/stretchr/testify/suite"
17
+)
18
+
19
+type ConnMock struct {
20
+	testlib.EssentialsConnMock
21
+
22
+	mu          sync.Mutex
23
+	writeBuffer bytes.Buffer
24
+}
25
+
26
+func (m *ConnMock) Write(p []byte) (int, error) {
27
+	args := m.Called(p)
28
+	if err := args.Error(1); err != nil {
29
+		return args.Int(0), err
30
+	}
31
+
32
+	m.mu.Lock()
33
+	defer m.mu.Unlock()
34
+
35
+	return m.writeBuffer.Write(p)
36
+}
37
+
38
+func (m *ConnMock) Written() []byte {
39
+	m.mu.Lock()
40
+	defer m.mu.Unlock()
41
+
42
+	return bytes.Clone(m.writeBuffer.Bytes())
43
+}
44
+
45
+type ConnTestSuite struct {
46
+	suite.Suite
47
+
48
+	connMock  *ConnMock
49
+	ctx       context.Context
50
+	ctxCancel context.CancelFunc
51
+}
52
+
53
+func (suite *ConnTestSuite) SetupTest() {
54
+	ctx, cancel := context.WithCancel(context.Background())
55
+	suite.ctx = ctx
56
+	suite.ctxCancel = cancel
57
+	suite.connMock = &ConnMock{}
58
+}
59
+
60
+func (suite *ConnTestSuite) TearDownTest() {
61
+	suite.ctxCancel()
62
+	suite.connMock.AssertExpectations(suite.T())
63
+}
64
+
65
+func (suite *ConnTestSuite) makeConn() Conn {
66
+	return NewConn(suite.ctx, suite.connMock, &Stats{
67
+		k:      2.0,
68
+		lambda: 0.01,
69
+	})
70
+}
71
+
72
+func (suite *ConnTestSuite) TestWriteBuffersData() {
73
+	suite.connMock.
74
+		On("Write", mock.AnythingOfType("[]uint8")).
75
+		Return(0, nil).
76
+		Maybe()
77
+
78
+	c := suite.makeConn()
79
+	defer c.Stop()
80
+
81
+	n, err := c.Write([]byte{1, 2, 3})
82
+	suite.NoError(err)
83
+	suite.Equal(3, n)
84
+}
85
+
86
+func (suite *ConnTestSuite) TestWriteOutputsTLSRecords() {
87
+	suite.connMock.
88
+		On("Write", mock.AnythingOfType("[]uint8")).
89
+		Return(0, nil).
90
+		Maybe()
91
+
92
+	c := suite.makeConn()
93
+
94
+	payload := []byte("hello doppelganger")
95
+	_, err := c.Write(payload)
96
+	suite.NoError(err)
97
+
98
+	suite.Eventually(func() bool {
99
+		return len(suite.connMock.Written()) > 0
100
+	}, 2*time.Second, time.Millisecond)
101
+
102
+	c.Stop()
103
+
104
+	assembled := &bytes.Buffer{}
105
+	reader := bytes.NewReader(suite.connMock.Written())
106
+
107
+	for {
108
+		header := make([]byte, tls.SizeHeader)
109
+		if _, err := io.ReadFull(reader, header); err != nil {
110
+			break
111
+		}
112
+
113
+		suite.Equal(byte(tls.TypeApplicationData), header[0])
114
+		suite.Equal(tls.TLSVersion[:], header[tls.SizeRecordType:tls.SizeRecordType+tls.SizeVersion])
115
+
116
+		length := binary.BigEndian.Uint16(header[tls.SizeRecordType+tls.SizeVersion:])
117
+		suite.Greater(length, uint16(0))
118
+
119
+		rec := make([]byte, length)
120
+		_, err := io.ReadFull(reader, rec)
121
+		suite.NoError(err)
122
+
123
+		assembled.Write(rec)
124
+	}
125
+
126
+	suite.Equal(payload, assembled.Bytes())
127
+}
128
+
129
+func (suite *ConnTestSuite) TestWriteReturnsErrorAfterStop() {
130
+	suite.connMock.
131
+		On("Write", mock.AnythingOfType("[]uint8")).
132
+		Return(0, nil).
133
+		Maybe()
134
+
135
+	c := suite.makeConn()
136
+	c.Stop()
137
+
138
+	time.Sleep(10 * time.Millisecond)
139
+
140
+	_, err := c.Write([]byte{1})
141
+	suite.Error(err)
142
+}
143
+
144
+func (suite *ConnTestSuite) TestStopOnUnderlyingWriteError() {
145
+	suite.connMock.
146
+		On("Write", mock.AnythingOfType("[]uint8")).
147
+		Return(0, errors.New("connection reset")).
148
+		Maybe()
149
+
150
+	c := suite.makeConn()
151
+
152
+	_, _ = c.Write([]byte("data"))
153
+
154
+	suite.Eventually(func() bool {
155
+		_, err := c.Write([]byte{1})
156
+		return err != nil
157
+	}, 2*time.Second, time.Millisecond)
158
+}
159
+
160
+func TestConn(t *testing.T) {
161
+	t.Parallel()
162
+	suite.Run(t, &ConnTestSuite{})
163
+}

+ 173
- 0
mtglib/internal/doppel/ganger.go Просмотреть файл

1
+package doppel
2
+
3
+import (
4
+	"context"
5
+	"sync"
6
+	"time"
7
+
8
+	"github.com/9seconds/mtg/v2/essentials"
9
+)
10
+
11
+const (
12
+	DoppelGangerMaxDurations     = 4096
13
+	DoppelGangerScoutMissionEach = 30 * time.Minute
14
+	DoppelGangerScoutRepeats     = 10
15
+)
16
+
17
+type gangerConnRequest struct {
18
+	ret     chan Conn
19
+	payload essentials.Conn
20
+}
21
+
22
+type Ganger struct {
23
+	ctx       context.Context
24
+	ctxCancel context.CancelFunc
25
+	logger    Logger
26
+	wg        sync.WaitGroup
27
+
28
+	scout               Scout
29
+	scoutMissionEach    time.Duration
30
+	scoutMissionRepeats int
31
+
32
+	stats     *Stats
33
+	durations []time.Duration
34
+
35
+	connRequests chan gangerConnRequest
36
+}
37
+
38
+func (g *Ganger) Shutdown() {
39
+	g.ctxCancel()
40
+	g.wg.Wait()
41
+}
42
+
43
+func (g *Ganger) Run() {
44
+	g.wg.Go(func() {
45
+		g.run()
46
+	})
47
+}
48
+
49
+func (g *Ganger) NewConn(conn essentials.Conn) (Conn, error) {
50
+	req := gangerConnRequest{
51
+		ret:     make(chan Conn),
52
+		payload: conn,
53
+	}
54
+	defer close(req.ret)
55
+
56
+	select {
57
+	case <-g.ctx.Done():
58
+		return Conn{}, context.Cause(g.ctx)
59
+	case g.connRequests <- req:
60
+	}
61
+
62
+	select {
63
+	case <-g.ctx.Done():
64
+		return Conn{}, context.Cause(g.ctx)
65
+	case conn := <-req.ret:
66
+		return conn, nil
67
+	}
68
+}
69
+
70
+func (g *Ganger) run() {
71
+	scoutTicker := time.NewTicker(g.scoutMissionEach)
72
+	defer func() {
73
+		scoutTicker.Stop()
74
+
75
+		select {
76
+		case <-scoutTicker.C:
77
+		default:
78
+		}
79
+	}()
80
+
81
+	scoutCollectedChan := make(chan []time.Duration)
82
+	currentScoutCollectedChan := scoutCollectedChan
83
+
84
+	updatedStatsChan := make(chan *Stats)
85
+
86
+	g.wg.Go(func() {
87
+		g.runScoutMission(scoutCollectedChan)
88
+	})
89
+
90
+	for {
91
+		select {
92
+		case <-g.ctx.Done():
93
+			return
94
+		case durations := <-currentScoutCollectedChan:
95
+			g.durations = append(g.durations, durations...)
96
+			if len(g.durations) > DoppelGangerMaxDurations {
97
+				g.durations = g.durations[len(g.durations)-DoppelGangerMaxDurations:]
98
+			}
99
+
100
+			currentScoutCollectedChan = nil
101
+			g.wg.Go(func() {
102
+				select {
103
+				case <-g.ctx.Done():
104
+				case updatedStatsChan <- NewStats(durations):
105
+				}
106
+			})
107
+		case stats := <-updatedStatsChan:
108
+			g.stats = stats
109
+			currentScoutCollectedChan = scoutCollectedChan
110
+		case <-scoutTicker.C:
111
+			g.wg.Go(func() {
112
+				g.runScoutMission(scoutCollectedChan)
113
+			})
114
+		case req := <-g.connRequests:
115
+			select {
116
+			case <-g.ctx.Done():
117
+			case req.ret <- NewConn(g.ctx, req.payload, g.stats):
118
+			}
119
+		}
120
+	}
121
+}
122
+
123
+func (g *Ganger) runScoutMission(rvChan chan<- []time.Duration) {
124
+	durations := []time.Duration{}
125
+
126
+	for range g.scoutMissionRepeats {
127
+		learned, err := g.scout.Learn(g.ctx)
128
+		if err != nil {
129
+			g.logger.WarningError("cannot learn", err)
130
+			continue
131
+		}
132
+		durations = append(durations, learned...)
133
+	}
134
+
135
+	select {
136
+	case <-g.ctx.Done():
137
+		return
138
+	case rvChan <- durations:
139
+	}
140
+}
141
+
142
+func NewGanger(
143
+	ctx context.Context,
144
+	network Network,
145
+	logger Logger,
146
+	scoutEach time.Duration,
147
+	scoutRepeats int,
148
+	urls []string,
149
+) *Ganger {
150
+	ctx, cancel := context.WithCancel(ctx)
151
+
152
+	if scoutEach == 0 {
153
+		scoutEach = DoppelGangerScoutMissionEach
154
+	}
155
+
156
+	if scoutRepeats == 0 {
157
+		scoutRepeats = DoppelGangerScoutRepeats
158
+	}
159
+
160
+	return &Ganger{
161
+		ctx:                 ctx,
162
+		ctxCancel:           cancel,
163
+		logger:              logger,
164
+		scoutMissionEach:    scoutEach,
165
+		scoutMissionRepeats: scoutRepeats,
166
+		stats: &Stats{
167
+			k:      StatsDefaultK,
168
+			lambda: StatsDefaultLambda,
169
+		},
170
+		scout:        NewScout(network, urls),
171
+		connRequests: make(chan gangerConnRequest),
172
+	}
173
+}

+ 107
- 0
mtglib/internal/doppel/ganger_test.go Просмотреть файл

1
+package doppel
2
+
3
+import (
4
+	"bytes"
5
+	"sync"
6
+	"testing"
7
+	"time"
8
+
9
+	"github.com/9seconds/mtg/v2/internal/testlib"
10
+	"github.com/stretchr/testify/mock"
11
+	"github.com/stretchr/testify/suite"
12
+)
13
+
14
+type GangerTestSuite struct {
15
+	TLSServerTestSuite
16
+
17
+	log *LoggerMock
18
+	g   *Ganger
19
+}
20
+
21
+func (suite *GangerTestSuite) SetupTest() {
22
+	suite.TLSServerTestSuite.SetupTest()
23
+
24
+	suite.log = &LoggerMock{}
25
+	suite.log.
26
+		On("Info", mock.AnythingOfType("string")).
27
+		Maybe()
28
+	suite.log.
29
+		On("WarningError", mock.AnythingOfType("string"), mock.Anything).
30
+		Maybe()
31
+
32
+	suite.g = NewGanger(suite.ctx, suite.network, suite.log, time.Hour, 1, suite.urls)
33
+	suite.g.Run()
34
+}
35
+
36
+func (suite *GangerTestSuite) TearDownTest() {
37
+	suite.g.Shutdown()
38
+
39
+	suite.log.AssertExpectations(suite.T())
40
+	suite.TLSServerTestSuite.TearDownTest()
41
+}
42
+
43
+func (suite *GangerTestSuite) TestNewConnAfterShutdown() {
44
+	suite.g.Shutdown()
45
+	connMock := &testlib.EssentialsConnMock{}
46
+
47
+	_, err := suite.g.NewConn(connMock)
48
+	suite.Error(err)
49
+}
50
+
51
+func (suite *GangerTestSuite) TestNewConnWhileRunning() {
52
+	connMock := &testlib.EssentialsConnMock{}
53
+	connMock.
54
+		On("Write", mock.AnythingOfType("[]uint8")).
55
+		Return(0, nil).
56
+		Maybe()
57
+	connMock.On("Close").
58
+		Return(nil).
59
+		Maybe()
60
+
61
+	conn, err := suite.g.NewConn(connMock)
62
+	suite.NoError(err)
63
+
64
+	conn.Stop()
65
+}
66
+
67
+func (suite *GangerTestSuite) TestNewConnWriteProducesTLSRecords() {
68
+	var (
69
+		mu  sync.Mutex
70
+		buf bytes.Buffer
71
+	)
72
+
73
+	connMock := &testlib.EssentialsConnMock{}
74
+	connMock.On("Write", mock.AnythingOfType("[]uint8")).
75
+		Run(func(args mock.Arguments) {
76
+			mu.Lock()
77
+			buf.Write(args.Get(0).([]byte))
78
+			mu.Unlock()
79
+		}).
80
+		Return(0, nil).
81
+		Maybe()
82
+	connMock.On("Close").
83
+		Return(nil).
84
+		Maybe()
85
+
86
+	conn, err := suite.g.NewConn(connMock)
87
+	suite.NoError(err)
88
+
89
+	payload := bytes.Repeat([]byte("x"), 512)
90
+	_, err = conn.Write(payload)
91
+	suite.NoError(err)
92
+
93
+	time.Sleep(500 * time.Millisecond)
94
+	conn.Stop()
95
+
96
+	mu.Lock()
97
+	written := buf.Bytes()
98
+	mu.Unlock()
99
+
100
+	suite.NotEmpty(written)
101
+}
102
+
103
+func TestGanger(t *testing.T) {
104
+	t.Parallel()
105
+
106
+	suite.Run(t, &GangerTestSuite{})
107
+}

+ 38
- 0
mtglib/internal/doppel/init.go Просмотреть файл

1
+package doppel
2
+
3
+import (
4
+	"context"
5
+	"net/http"
6
+	"time"
7
+
8
+	"github.com/9seconds/mtg/v2/essentials"
9
+	"github.com/9seconds/mtg/v2/mtglib/internal/tls"
10
+)
11
+
12
+const (
13
+	// Please see Stats description
14
+	// https://blog.cloudflare.com/optimizing-tls-over-tcp-to-reduce-latency/
15
+	// https://github.com/cloudflare/sslconfig/blob/master/patches/nginx__dynamic_tls_records.patch
16
+	TLSRecordSizeStart = 1369
17
+	TLSRecordSizeAccel = 4229
18
+	TLSRecordSizeMax   = 16384 - tls.SizeHeader
19
+
20
+	TLSCounterAccelAfter = 40
21
+	TLSCounterMaxAfter   = TLSCounterAccelAfter + 20
22
+
23
+	TLSRecordSizeResetAfter = time.Second
24
+)
25
+
26
+// copypasted from mtglib
27
+type Network interface {
28
+	// Dial establishes context-free TCP connections.
29
+	Dial(network, address string) (essentials.Conn, error)
30
+
31
+	// DialContext dials using a context. This is a preferrable way of
32
+	// establishing TCP connections.
33
+	DialContext(ctx context.Context, network, address string) (essentials.Conn, error)
34
+
35
+	// MakeHTTPClient build an HTTP client with given dial function. If nothing is
36
+	// provided, then DialContext of this interface is going to be used.
37
+	MakeHTTPClient(func(ctx context.Context, network, address string) (essentials.Conn, error)) *http.Client
38
+}

+ 104
- 0
mtglib/internal/doppel/init_test.go Просмотреть файл

1
+package doppel
2
+
3
+import (
4
+	"context"
5
+	"crypto/tls"
6
+	"net"
7
+	"net/http"
8
+	"net/http/httptest"
9
+	"time"
10
+
11
+	"github.com/9seconds/mtg/v2/essentials"
12
+	"github.com/stretchr/testify/mock"
13
+	"github.com/stretchr/testify/suite"
14
+)
15
+
16
+type SimpleNetwork struct {
17
+}
18
+
19
+func (s SimpleNetwork) Dial(network, address string) (essentials.Conn, error) {
20
+	return s.DialContext(context.Background(), network, address)
21
+}
22
+
23
+func (s SimpleNetwork) DialContext(ctx context.Context, network, address string) (essentials.Conn, error) {
24
+	d := &net.Dialer{}
25
+
26
+	conn, err := d.DialContext(ctx, network, address)
27
+	if err != nil {
28
+		return nil, err
29
+	}
30
+
31
+	return conn.(*net.TCPConn), nil
32
+}
33
+
34
+func (s SimpleNetwork) MakeHTTPClient(dialFunc func(ctx context.Context, network, address string) (essentials.Conn, error)) *http.Client {
35
+	if dialFunc == nil {
36
+		dialFunc = s.DialContext
37
+	}
38
+
39
+	return &http.Client{
40
+		Transport: &http.Transport{
41
+			TLSClientConfig: &tls.Config{
42
+				InsecureSkipVerify: true, //nolint: gosec
43
+			},
44
+			DialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
45
+				return dialFunc(ctx, network, address)
46
+			},
47
+		},
48
+	}
49
+}
50
+
51
+type TLSServerTestSuite struct {
52
+	suite.Suite
53
+
54
+	tlsServer *httptest.Server
55
+	ctx       context.Context
56
+	ctxCancel context.CancelFunc
57
+	network   SimpleNetwork
58
+	urls      []string
59
+}
60
+
61
+func (suite *TLSServerTestSuite) SetupSuite() {
62
+	suite.tlsServer = httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
63
+		w.WriteHeader(http.StatusOK)
64
+		w.Header().Add("Hello", "how long")
65
+
66
+		if _, err := w.Write([]byte{1, 2, 3}); err != nil {
67
+			panic(err)
68
+		}
69
+
70
+		time.Sleep(5 * time.Millisecond)
71
+
72
+		if _, err := w.Write([]byte{1, 2, 3}); err != nil {
73
+			panic(err)
74
+		}
75
+	}))
76
+	suite.urls = []string{suite.tlsServer.URL}
77
+}
78
+
79
+func (suite *TLSServerTestSuite) SetupTest() {
80
+	ctx, cancel := context.WithCancel(context.Background())
81
+	suite.ctx = ctx
82
+	suite.ctxCancel = cancel
83
+}
84
+
85
+func (suite *TLSServerTestSuite) TearDownTest() {
86
+	suite.ctxCancel()
87
+	suite.tlsServer.CloseClientConnections()
88
+}
89
+
90
+func (suite *TLSServerTestSuite) TearDownSuite() {
91
+	suite.tlsServer.Close()
92
+}
93
+
94
+type LoggerMock struct {
95
+	mock.Mock
96
+}
97
+
98
+func (l *LoggerMock) Info(msg string) {
99
+	l.Called(msg)
100
+}
101
+
102
+func (l *LoggerMock) WarningError(msg string, err error) {
103
+	l.Called(msg, err)
104
+}

+ 6
- 0
mtglib/internal/doppel/logger.go Просмотреть файл

1
+package doppel
2
+
3
+type Logger interface {
4
+	Info(msg string)
5
+	WarningError(msg string, err error)
6
+}

+ 104
- 0
mtglib/internal/doppel/scout.go Просмотреть файл

1
+package doppel
2
+
3
+import (
4
+	"context"
5
+	"fmt"
6
+	"io"
7
+	"net/http"
8
+	"strings"
9
+	"time"
10
+
11
+	"github.com/9seconds/mtg/v2/essentials"
12
+	"github.com/9seconds/mtg/v2/mtglib/internal/tls"
13
+)
14
+
15
+type Scout struct {
16
+	network Network
17
+	urls    []string
18
+}
19
+
20
+func (s Scout) Learn(ctx context.Context) ([]time.Duration, error) {
21
+	var durations []time.Duration
22
+
23
+	for _, url := range s.urls {
24
+		learned, err := s.learn(ctx, url)
25
+		if err != nil {
26
+			return nil, err
27
+		}
28
+
29
+		durations = append(durations, learned...)
30
+	}
31
+
32
+	return durations, nil
33
+}
34
+
35
+func (s Scout) learn(ctx context.Context, url string) ([]time.Duration, error) {
36
+	client, results := s.makeClient()
37
+
38
+	if !strings.HasPrefix(url, "https://") {
39
+		return nil, fmt.Errorf("url %s must be https", url)
40
+	}
41
+
42
+	req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
43
+	if err != nil {
44
+		return nil, err
45
+	}
46
+
47
+	resp, err := client.Do(req)
48
+	if resp != nil {
49
+		io.Copy(io.Discard, resp.Body) //nolint: errcheck
50
+		resp.Body.Close()              //nolint: errcheck
51
+		client.CloseIdleConnections()
52
+	}
53
+
54
+	if err != nil || len(results.data) == 0 {
55
+		return nil, err
56
+	}
57
+
58
+	durations := []time.Duration{}
59
+	lastTimestamp := time.Time{}
60
+
61
+	for i, v := range results.data {
62
+		if v.recordType != tls.TypeApplicationData {
63
+			continue
64
+		}
65
+
66
+		if lastTimestamp.IsZero() {
67
+			if i > 0 {
68
+				lastTimestamp = results.data[i-1].timestamp
69
+			} else {
70
+				lastTimestamp = v.timestamp
71
+			}
72
+		}
73
+
74
+		durations = append(durations, v.timestamp.Sub(lastTimestamp))
75
+		lastTimestamp = v.timestamp
76
+	}
77
+
78
+	return durations, nil
79
+}
80
+
81
+func (s Scout) makeClient() (*http.Client, *ScoutConnCollected) {
82
+	collected := NewScoutConnCollected()
83
+	client := s.network.MakeHTTPClient(func(
84
+		ctx context.Context,
85
+		network string,
86
+		address string,
87
+	) (essentials.Conn, error) {
88
+		conn, err := s.network.DialContext(ctx, network, address)
89
+		if err != nil {
90
+			return nil, err
91
+		}
92
+
93
+		return NewScoutConn(conn, collected), nil
94
+	})
95
+
96
+	return client, collected
97
+}
98
+
99
+func NewScout(network Network, urls []string) Scout {
100
+	return Scout{
101
+		network: network,
102
+		urls:    urls,
103
+	}
104
+}

+ 57
- 0
mtglib/internal/doppel/scout_conn.go Просмотреть файл

1
+package doppel
2
+
3
+import (
4
+	"bytes"
5
+	"encoding/binary"
6
+	"io"
7
+
8
+	"github.com/9seconds/mtg/v2/essentials"
9
+	"github.com/9seconds/mtg/v2/mtglib/internal/tls"
10
+)
11
+
12
+type ScoutConn struct {
13
+	tls.Conn
14
+
15
+	results *ScoutConnCollected
16
+	rawBuf  *bytes.Buffer
17
+}
18
+
19
+func (s ScoutConn) Read(p []byte) (int, error) {
20
+	buf := &bytes.Buffer{}
21
+
22
+	for {
23
+		if n, err := s.rawBuf.Read(p); err == nil {
24
+			return n, nil
25
+		}
26
+
27
+		s.rawBuf.Reset()
28
+
29
+		recordType, length, err := tls.ReadRecord(s.Conn, buf)
30
+		if err != nil {
31
+			return 0, err
32
+		}
33
+
34
+		s.results.Add(recordType)
35
+		s.rawBuf.Write([]byte{recordType})
36
+		s.rawBuf.Write(tls.TLSVersion[:])
37
+
38
+		if err := binary.Write(s.rawBuf, binary.BigEndian, uint16(length)); err != nil {
39
+			return 0, err
40
+		}
41
+
42
+		if _, err := io.Copy(s.rawBuf, buf); err != nil {
43
+			return 0, err
44
+		}
45
+	}
46
+}
47
+
48
+func NewScoutConn(conn essentials.Conn, results *ScoutConnCollected) ScoutConn {
49
+	rawBuf := &bytes.Buffer{}
50
+	rawBuf.Grow(tls.MaxRecordSize)
51
+
52
+	return ScoutConn{
53
+		Conn:    tls.New(conn, false, false),
54
+		results: results,
55
+		rawBuf:  rawBuf,
56
+	}
57
+}

+ 29
- 0
mtglib/internal/doppel/scout_conn_collected.go Просмотреть файл

1
+package doppel
2
+
3
+import "time"
4
+
5
+const (
6
+	ScoutConnCollectedPreallocSize = 100
7
+)
8
+
9
+type ScoutConnResult struct {
10
+	timestamp  time.Time
11
+	recordType byte
12
+}
13
+
14
+type ScoutConnCollected struct {
15
+	data []ScoutConnResult
16
+}
17
+
18
+func (s *ScoutConnCollected) Add(record byte) {
19
+	s.data = append(s.data, ScoutConnResult{
20
+		timestamp:  time.Now(),
21
+		recordType: record,
22
+	})
23
+}
24
+
25
+func NewScoutConnCollected() *ScoutConnCollected {
26
+	return &ScoutConnCollected{
27
+		data: make([]ScoutConnResult, 0, ScoutConnCollectedPreallocSize),
28
+	}
29
+}

+ 42
- 0
mtglib/internal/doppel/scout_conn_collected_test.go Просмотреть файл

1
+package doppel
2
+
3
+import (
4
+	"testing"
5
+	"time"
6
+
7
+	"github.com/9seconds/mtg/v2/mtglib/internal/tls"
8
+	"github.com/stretchr/testify/suite"
9
+)
10
+
11
+type ScoutConnCollectedTestSuite struct {
12
+	suite.Suite
13
+}
14
+
15
+func (suite *ScoutConnCollectedTestSuite) TestAddSingle() {
16
+	collected := NewScoutConnCollected()
17
+	collected.Add(tls.TypeApplicationData)
18
+
19
+	suite.Len(collected.data, 1)
20
+	suite.Equal(byte(tls.TypeApplicationData), collected.data[0].recordType)
21
+}
22
+
23
+func (suite *ScoutConnCollectedTestSuite) TestAddTimestampsAreMonotonic() {
24
+	collected := NewScoutConnCollected()
25
+
26
+	collected.Add(tls.TypeApplicationData)
27
+
28
+	time.Sleep(time.Microsecond)
29
+	collected.Add(tls.TypeApplicationData)
30
+
31
+	time.Sleep(time.Microsecond)
32
+	collected.Add(tls.TypeApplicationData)
33
+
34
+	for i := 1; i < len(collected.data); i++ {
35
+		suite.True(collected.data[i].timestamp.After(collected.data[i-1].timestamp))
36
+	}
37
+}
38
+
39
+func TestScoutConnCollected(t *testing.T) {
40
+	t.Parallel()
41
+	suite.Run(t, &ScoutConnCollectedTestSuite{})
42
+}

+ 39
- 0
mtglib/internal/doppel/scout_test.go Просмотреть файл

1
+package doppel
2
+
3
+import (
4
+	"testing"
5
+
6
+	"github.com/stretchr/testify/suite"
7
+)
8
+
9
+type ScoutTestSuite struct {
10
+	TLSServerTestSuite
11
+
12
+	scout Scout
13
+}
14
+
15
+func (suite *ScoutTestSuite) SetupSuite() {
16
+	suite.TLSServerTestSuite.SetupSuite()
17
+
18
+	suite.scout = Scout{
19
+		network: suite.network,
20
+		urls:    suite.urls,
21
+	}
22
+}
23
+
24
+func (suite *ScoutTestSuite) TestCollectResults() {
25
+	durations, err := suite.scout.Learn(suite.ctx)
26
+	suite.NoError(err)
27
+	suite.Less(3, len(durations))
28
+}
29
+
30
+func (suite *ScoutTestSuite) TestCollectNothing() {
31
+	suite.ctxCancel()
32
+
33
+	_, err := suite.scout.Learn(suite.ctx)
34
+	suite.Error(err)
35
+}
36
+
37
+func TestScout(t *testing.T) {
38
+	suite.Run(t, &ScoutTestSuite{})
39
+}

+ 150
- 0
mtglib/internal/doppel/stats.go Просмотреть файл

1
+package doppel
2
+
3
+import (
4
+	"math"
5
+	"math/rand/v2"
6
+	"time"
7
+)
8
+
9
+const (
10
+	StatsBisectTimes = 70
11
+	StatsLowK        = 0.01
12
+	StatsHighK       = 10.0
13
+
14
+	StatsDefaultK      = 0.6
15
+	StatsDefaultLambda = 0.002
16
+)
17
+
18
+// Stats is responsible for generating values that are distributed according
19
+// to some statistical distribution.
20
+//
21
+// It follows several ideas:
22
+//  1. Based on nginx and Cloudflare behaviour, even if server is eager
23
+//     to send a lot, they all start with small TLS packets that are
24
+//     approximately MTU-sized. After
25
+//  2. After ~40 TLS records, server considers TCP session as somewhat solid
26
+//     and reliable and ramps up to 4096.
27
+//  3. After ~20 TLS records more it jumps to the max 16384 bytes and keep
28
+//     this size as long as it can
29
+//  4. If there is no any byte within a connection for a longer time period,
30
+//     this counter resets.
31
+//
32
+// This is called Dynamic TLS Record Sizing
33
+//   - https://blog.cloudflare.com/optimizing-tls-over-tcp-to-reduce-latency/
34
+//   - https://community.f5.com/kb/technicalarticles/boosting-tls-performance-with-dynamic-record-sizing-on-big-ip/280798
35
+//   - https://www.igvita.com/2013/10/24/optimizing-tls-record-size-and-buffering-latency/
36
+//
37
+// And this optimized for the very first byte, so web browsers could start to
38
+// render as early as possible, showing user some preliminary results, optimizing
39
+// for perceived latency.
40
+//
41
+// Since this is very typical for the website, we also aim for that.
42
+//
43
+// Another important idea is how delays between TLS packets are distributed.
44
+// In case of sending huge heavy content with max sized record, delays have
45
+// lognormal distribution. But a nature of a typical website shows that
46
+// it eagers to deliver as fast as it can in a few very first records and
47
+// could possibly slow down later.
48
+//
49
+// This is perfectly described by Weibull distribution:
50
+//   - https://en.wikipedia.org/wiki/Weibull_distribution
51
+//   - https://ieeexplore.ieee.org/document/6662948
52
+//   - https://www.researchgate.net/publication/224621285_Traffic_modelling_and_cost_optimization_for_transmitting_traffic_messages_over_a_hybrid_broadcast_and_cellular_network
53
+//   - https://ir.uitm.edu.my/id/eprint/105386/1/105386.pdf
54
+//
55
+// In other word, a combination of Dynamic TLS Record Sizing hints us for
56
+// Weibull distribution.
57
+type Stats struct {
58
+	sizeLastRequested time.Time
59
+	sizeCounter       int
60
+
61
+	// https://en.wikipedia.org/wiki/Shape_parameter
62
+	k float64
63
+	// https://en.wikipedia.org/wiki/Scale_parameter
64
+	lambda float64
65
+}
66
+
67
+func (d *Stats) Delay() time.Duration {
68
+	// u ∈ (0, 1], avoids ln(0)
69
+	u := 1.0 - rand.Float64()
70
+
71
+	// X = λ·(-ln U)^(1/k)
72
+	generated := d.lambda * math.Pow(-math.Log(u), 1.0/d.k)
73
+
74
+	// generated is in milliseconds
75
+	return time.Duration(generated * float64(time.Millisecond))
76
+}
77
+
78
+func (d *Stats) Size() int {
79
+	if time.Since(d.sizeLastRequested) > TLSRecordSizeResetAfter {
80
+		d.sizeCounter = 0
81
+	}
82
+
83
+	d.sizeLastRequested = time.Now()
84
+	d.sizeCounter++
85
+
86
+	switch {
87
+	case d.sizeCounter <= TLSCounterAccelAfter:
88
+		return TLSRecordSizeStart
89
+	case d.sizeCounter <= TLSCounterMaxAfter:
90
+		return TLSRecordSizeAccel
91
+	}
92
+
93
+	return TLSRecordSizeMax
94
+}
95
+
96
+func NewStats(durations []time.Duration) *Stats {
97
+	n := float64(len(durations))
98
+
99
+	// in milliseconds
100
+	durFloats := make([]float64, len(durations))
101
+	for i, v := range durations {
102
+		durFloats[i] = float64(v.Microseconds()) / 1000.0
103
+	}
104
+
105
+	// The bisection solves the standard Weibull MLE equation for shape
106
+	// parameter k. There is no any good formula for doing that so we
107
+	// approximate it by several bisections. The number of operations
108
+	// is statically defined by a constant.
109
+
110
+	sumLog := 0.0
111
+	for _, v := range durFloats {
112
+		sumLog += math.Log(v)
113
+	}
114
+
115
+	lowK := StatsLowK
116
+	highK := StatsHighK
117
+
118
+	for range StatsBisectTimes {
119
+		midK := (lowK + highK) / 2.0
120
+		sumXK := 0.0
121
+		sumXKLog := 0.0
122
+
123
+		for _, v := range durFloats {
124
+			xk := math.Pow(v, midK)
125
+			sumXK += xk
126
+			sumXKLog += xk * math.Log(v)
127
+		}
128
+
129
+		if (1.0/midK)+(sumLog/n)-(sumXKLog/sumXK) > 0 {
130
+			lowK = midK
131
+		} else {
132
+			highK = midK
133
+		}
134
+	}
135
+
136
+	k := (lowK + highK) / 2
137
+
138
+	sumXK := 0.0
139
+	for _, v := range durFloats {
140
+		sumXK += math.Pow(v, k)
141
+	}
142
+
143
+	// λ = (Σxᵢᵏ / n)^(1/k)
144
+	lambda := math.Pow(sumXK/n, 1.0/k)
145
+
146
+	return &Stats{
147
+		k:      k,
148
+		lambda: lambda,
149
+	}
150
+}

+ 194
- 0
mtglib/internal/doppel/stats_test.go Просмотреть файл

1
+package doppel
2
+
3
+import (
4
+	"math"
5
+	"math/rand/v2"
6
+	"testing"
7
+	"time"
8
+
9
+	"github.com/stretchr/testify/suite"
10
+)
11
+
12
+type StatsTestSuite struct {
13
+	suite.Suite
14
+}
15
+
16
+func (suite *StatsTestSuite) GenWeibull(k, lambda float64, n int, seed uint64) []time.Duration {
17
+	rng := rand.New(rand.NewPCG(seed, 0))
18
+	samples := make([]time.Duration, n)
19
+
20
+	for i := range samples {
21
+		u := 1.0 - rng.Float64()
22
+		ms := lambda * math.Pow(-math.Log(u), 1.0/k)
23
+		d := time.Duration(ms * float64(time.Millisecond))
24
+
25
+		if d < time.Microsecond {
26
+			time.Sleep(time.Microsecond)
27
+			d = time.Microsecond
28
+		}
29
+
30
+		samples[i] = d
31
+	}
32
+
33
+	return samples
34
+}
35
+
36
+func (suite *StatsTestSuite) TestNewStatsRecoverParameters() {
37
+	knownK := 1.5
38
+	knownLambda := 100.0
39
+
40
+	samples := suite.GenWeibull(knownK, knownLambda, 5000, 42)
41
+	stats := NewStats(samples)
42
+
43
+	suite.InDelta(knownK, stats.k, 0.1)
44
+	suite.InDelta(knownLambda, stats.lambda, 5.0)
45
+}
46
+
47
+func (suite *StatsTestSuite) TestNewStatsExponentialCase() {
48
+	// When k=1, Weibull reduces to exponential distribution.
49
+	knownK := 1.0
50
+	knownLambda := 50.0
51
+
52
+	samples := suite.GenWeibull(knownK, knownLambda, 5000, 123)
53
+	stats := NewStats(samples)
54
+
55
+	suite.InDelta(knownK, stats.k, 0.1)
56
+	suite.InDelta(knownLambda, stats.lambda, 5.0)
57
+}
58
+
59
+func (suite *StatsTestSuite) TestNewStatsSmallK() {
60
+	// k < 1 produces a heavy-tailed distribution typical for network delays.
61
+	// Lambda must be large enough so samples stay above microsecond precision
62
+	// after time.Duration round-trip.
63
+	knownK := 0.6
64
+	knownLambda := 100.0
65
+
66
+	samples := suite.GenWeibull(knownK, knownLambda, 10000, 99)
67
+	stats := NewStats(samples)
68
+
69
+	suite.InDelta(knownK, stats.k, 0.05)
70
+	suite.InDelta(knownLambda, stats.lambda, 5.0)
71
+}
72
+
73
+func (suite *StatsTestSuite) TestNewStatsLargeK() {
74
+	// k > 1: light tail, concentrated around the mode.
75
+	knownK := 5.0
76
+	knownLambda := 200.0
77
+
78
+	samples := suite.GenWeibull(knownK, knownLambda, 5000, 77)
79
+	stats := NewStats(samples)
80
+
81
+	suite.InDelta(knownK, stats.k, 0.3)
82
+	suite.InDelta(knownLambda, stats.lambda, 5.0)
83
+}
84
+
85
+func (suite *StatsTestSuite) TestDelayNonNegative() {
86
+	stats := &Stats{
87
+		k:      1.5,
88
+		lambda: 100.0,
89
+	}
90
+
91
+	for range 200 {
92
+		dur := stats.Delay()
93
+		suite.GreaterOrEqual(dur, time.Duration(0))
94
+	}
95
+}
96
+
97
+func (suite *StatsTestSuite) TestDelayDistributionMean() {
98
+	// Weibull mean = λ · Γ(1 + 1/k)
99
+	k := 2.0
100
+	lambda := 50.0
101
+	stats := &Stats{k: k, lambda: lambda}
102
+
103
+	n := 50000
104
+	sum := 0.0
105
+
106
+	for range n {
107
+		dur := stats.Delay()
108
+		sum += float64(dur) / float64(time.Millisecond)
109
+	}
110
+
111
+	sampleMean := sum / float64(n)
112
+	expectedMean := lambda * math.Gamma(1.0+1.0/k)
113
+
114
+	suite.InDelta(expectedMean, sampleMean, expectedMean*0.05)
115
+}
116
+
117
+func (suite *StatsTestSuite) TestNewStatsRoundTrip() {
118
+	// Estimate parameters from data, then verify that Delay samples
119
+	// from the fitted distribution have approximately the same mean.
120
+	knownK := 1.2
121
+	knownLambda := 80.0
122
+
123
+	samples := suite.GenWeibull(knownK, knownLambda, 5000, 555)
124
+	stats := NewStats(samples)
125
+
126
+	n := 50000
127
+	sum := 0.0
128
+
129
+	for range n {
130
+		dur := stats.Delay()
131
+		sum += float64(dur) / float64(time.Millisecond)
132
+	}
133
+
134
+	sampleMean := sum / float64(n)
135
+	expectedMean := knownLambda * math.Gamma(1.0+1.0/knownK)
136
+
137
+	suite.InDelta(expectedMean, sampleMean, expectedMean*0.05)
138
+}
139
+
140
+func (suite *StatsTestSuite) TestSizeStartPhase() {
141
+	stats := &Stats{k: 1.0, lambda: 1.0}
142
+
143
+	for range TLSCounterAccelAfter {
144
+		size := stats.Size()
145
+		suite.Equal(TLSRecordSizeStart, size)
146
+	}
147
+}
148
+
149
+func (suite *StatsTestSuite) TestSizeAccelPhase() {
150
+	stats := &Stats{k: 1.0, lambda: 1.0}
151
+
152
+	for range TLSCounterAccelAfter {
153
+		stats.Size()
154
+	}
155
+
156
+	for range TLSCounterMaxAfter - TLSCounterAccelAfter {
157
+		size := stats.Size()
158
+		suite.Equal(TLSRecordSizeAccel, size)
159
+	}
160
+}
161
+
162
+func (suite *StatsTestSuite) TestSizeMaxPhase() {
163
+	stats := &Stats{k: 1.0, lambda: 1.0}
164
+
165
+	for range TLSCounterMaxAfter {
166
+		stats.Size()
167
+	}
168
+
169
+	for range 20 {
170
+		size := stats.Size()
171
+		suite.Equal(TLSRecordSizeMax, size)
172
+	}
173
+}
174
+
175
+func (suite *StatsTestSuite) TestSizeResetsAfterInactivity() {
176
+	stats := &Stats{k: 1.0, lambda: 1.0}
177
+
178
+	// Advance past start phase.
179
+	for range TLSCounterMaxAfter {
180
+		stats.Size()
181
+	}
182
+
183
+	suite.Equal(TLSRecordSizeMax, stats.Size())
184
+
185
+	// Simulate inactivity by backdating sizeLastRequested.
186
+	stats.sizeLastRequested = time.Now().Add(-TLSRecordSizeResetAfter - time.Millisecond)
187
+
188
+	suite.Equal(TLSRecordSizeStart, stats.Size())
189
+}
190
+
191
+func TestStats(t *testing.T) {
192
+	t.Parallel()
193
+	suite.Run(t, &StatsTestSuite{})
194
+}

+ 88
- 0
mtglib/internal/tls/conn.go Просмотреть файл

1
+package tls
2
+
3
+import (
4
+	"bufio"
5
+	"bytes"
6
+
7
+	"github.com/9seconds/mtg/v2/essentials"
8
+)
9
+
10
+const (
11
+	SizeRecordType = 1
12
+	SizeVersion    = 2
13
+	SizeSize       = 2
14
+	SizeHeader     = SizeRecordType + SizeVersion + SizeSize
15
+
16
+	MaxRecordSize        = 16384
17
+	MaxRecordPayloadSize = MaxRecordSize - SizeHeader
18
+	DefaultBufferSize    = 4096
19
+
20
+	TypeChangeCipherSpec = 0x14
21
+	TypeHandshake        = 0x16
22
+	TypeApplicationData  = 0x17
23
+)
24
+
25
+var (
26
+	// TLS 1.2 is used for both TLS 1.2 and 1.3
27
+	TLSVersion = [SizeVersion]byte{3, 3}
28
+)
29
+
30
+// Conn presents an established TLS 1.3 connection, after handshake
31
+type Conn struct {
32
+	essentials.Conn
33
+
34
+	p *connPayload
35
+}
36
+
37
+type connPayload struct {
38
+	readBuf      bytes.Buffer
39
+	writeBuf     bytes.Buffer
40
+	connBuffered *bufio.Reader
41
+	read         bool
42
+	write        bool
43
+}
44
+
45
+func (c Conn) Write(p []byte) (int, error) {
46
+	if !c.p.write {
47
+		return c.Conn.Write(p)
48
+	}
49
+
50
+	return len(p), WriteRecord(c.Conn, p)
51
+}
52
+
53
+func (c Conn) Read(p []byte) (int, error) {
54
+	if !c.p.read {
55
+		return c.Conn.Read(p)
56
+	}
57
+
58
+	for {
59
+		if n, err := c.p.readBuf.Read(p); err == nil {
60
+			return n, nil
61
+		}
62
+
63
+		recordType, _, err := ReadRecord(c.p.connBuffered, &c.p.readBuf)
64
+		if err != nil {
65
+			return 0, err
66
+		}
67
+
68
+		if recordType != TypeApplicationData {
69
+			c.p.readBuf.Reset()
70
+		}
71
+	}
72
+}
73
+
74
+func New(conn essentials.Conn, read, write bool) Conn {
75
+	newConn := Conn{
76
+		Conn: conn,
77
+		p: &connPayload{
78
+			connBuffered: bufio.NewReaderSize(conn, DefaultBufferSize),
79
+			read:         read,
80
+			write:        write,
81
+		},
82
+	}
83
+
84
+	newConn.p.readBuf.Grow(DefaultBufferSize)
85
+	newConn.p.writeBuf.Grow(DefaultBufferSize)
86
+
87
+	return newConn
88
+}

+ 160
- 0
mtglib/internal/tls/conn_test.go Просмотреть файл

1
+package tls
2
+
3
+import (
4
+	"io"
5
+	"testing"
6
+
7
+	"github.com/9seconds/mtg/v2/internal/testlib"
8
+	"github.com/stretchr/testify/mock"
9
+	"github.com/stretchr/testify/suite"
10
+)
11
+
12
+type ConnTestSuite struct {
13
+	suite.Suite
14
+
15
+	connMock *testlib.EssentialsConnMock
16
+}
17
+
18
+func (suite *ConnTestSuite) SetupTest() {
19
+	suite.connMock = &testlib.EssentialsConnMock{}
20
+}
21
+
22
+func (suite *ConnTestSuite) TearDownTest() {
23
+	suite.connMock.AssertExpectations(suite.T())
24
+}
25
+
26
+func (suite *ConnTestSuite) feedRead(raw []byte) {
27
+	suite.connMock.
28
+		On("Read", mock.AnythingOfType("[]uint8")).
29
+		Run(func(args mock.Arguments) {
30
+			copy(args.Get(0).([]byte), raw)
31
+		}).
32
+		Return(len(raw), nil).
33
+		Once()
34
+	suite.connMock.
35
+		On("Read", mock.AnythingOfType("[]uint8")).
36
+		Return(0, io.EOF).
37
+		Maybe()
38
+}
39
+
40
+func (suite *ConnTestSuite) TestReadTLSEnabled() {
41
+	payload := []byte("hello world")
42
+	suite.feedRead(MakeTLSRecord(0x17, payload))
43
+
44
+	conn := New(suite.connMock, true, false)
45
+
46
+	buf := make([]byte, 128)
47
+	n, err := conn.Read(buf)
48
+
49
+	suite.NoError(err)
50
+	suite.Equal(payload, buf[:n])
51
+}
52
+
53
+func (suite *ConnTestSuite) TestReadTLSSkipsNonApplicationData() {
54
+	raw := append(
55
+		MakeTLSRecord(0x14, []byte{1}),
56
+		MakeTLSRecord(0x17, []byte("real data"))...,
57
+	)
58
+	suite.feedRead(raw)
59
+
60
+	conn := New(suite.connMock, true, false)
61
+
62
+	buf := make([]byte, 128)
63
+	n, err := conn.Read(buf)
64
+
65
+	suite.NoError(err)
66
+	suite.Equal([]byte("real data"), buf[:n])
67
+}
68
+
69
+func (suite *ConnTestSuite) TestReadTLSMultipleRecords() {
70
+	raw := append(
71
+		MakeTLSRecord(0x17, []byte("first")),
72
+		MakeTLSRecord(0x17, []byte("second"))...,
73
+	)
74
+	suite.feedRead(raw)
75
+
76
+	conn := New(suite.connMock, true, false)
77
+	buf := make([]byte, 128)
78
+
79
+	n, err := conn.Read(buf)
80
+	suite.NoError(err)
81
+	suite.Equal([]byte("first"), buf[:n])
82
+
83
+	n, err = conn.Read(buf)
84
+	suite.NoError(err)
85
+	suite.Equal([]byte("second"), buf[:n])
86
+}
87
+
88
+func (suite *ConnTestSuite) TestReadTLSSmallBuffer() {
89
+	payload := []byte("hello world, this is a longer payload")
90
+	suite.feedRead(MakeTLSRecord(0x17, payload))
91
+
92
+	conn := New(suite.connMock, true, false)
93
+
94
+	small := make([]byte, 5)
95
+	n, err := conn.Read(small)
96
+	suite.NoError(err)
97
+	suite.Equal(payload[:5], small[:n])
98
+
99
+	rest := make([]byte, 128)
100
+	n, err = conn.Read(rest)
101
+	suite.NoError(err)
102
+	suite.Equal(payload[5:], rest[:n])
103
+}
104
+
105
+func (suite *ConnTestSuite) TestReadPassthrough() {
106
+	data := []byte("raw bytes")
107
+
108
+	suite.connMock.
109
+		On("Read", mock.AnythingOfType("[]uint8")).
110
+		Run(func(args mock.Arguments) {
111
+			copy(args.Get(0).([]byte), data)
112
+		}).
113
+		Return(len(data), nil).
114
+		Once()
115
+
116
+	conn := New(suite.connMock, false, false)
117
+
118
+	buf := make([]byte, 128)
119
+	n, err := conn.Read(buf)
120
+
121
+	suite.NoError(err)
122
+	suite.Equal(data, buf[:n])
123
+}
124
+
125
+func (suite *ConnTestSuite) TestWritePassthrough() {
126
+	data := []byte("outgoing data")
127
+
128
+	suite.connMock.
129
+		On("Write", mock.AnythingOfType("[]uint8")).
130
+		Return(len(data), nil).
131
+		Once()
132
+
133
+	conn := New(suite.connMock, false, false)
134
+
135
+	n, err := conn.Write(data)
136
+
137
+	suite.NoError(err)
138
+	suite.Equal(len(data), n)
139
+}
140
+
141
+func (suite *ConnTestSuite) TestWriteTLSEnabled() {
142
+	data := []byte("outgoing data")
143
+
144
+	suite.connMock.
145
+		On("Write", mock.AnythingOfType("[]uint8")).
146
+		Return(len(data), nil).
147
+		Once()
148
+
149
+	conn := New(suite.connMock, false, true)
150
+
151
+	n, err := conn.Write(data)
152
+
153
+	suite.NoError(err)
154
+	suite.Equal(len(data), n)
155
+}
156
+
157
+func TestConn(t *testing.T) {
158
+	t.Parallel()
159
+	suite.Run(t, &ConnTestSuite{})
160
+}

+ 30
- 0
mtglib/internal/tls/init_test.go Просмотреть файл

1
+package tls
2
+
3
+import (
4
+	"encoding/binary"
5
+
6
+	"github.com/stretchr/testify/mock"
7
+)
8
+
9
+type WriterMock struct {
10
+	mock.Mock
11
+}
12
+
13
+func (m *WriterMock) Write(p []byte) (int, error) {
14
+	args := m.Called(p)
15
+	return args.Int(0), args.Error(1)
16
+}
17
+
18
+// makeTLSRecord builds a raw TLS record from hardcoded offsets:
19
+// type(1) + version(2, {3,3}) + length(2, big-endian) + payload.
20
+func MakeTLSRecord(recordType byte, payload []byte) []byte {
21
+	buf := make([]byte, 5+len(payload))
22
+
23
+	buf[0] = recordType
24
+	buf[1] = 3
25
+	buf[2] = 3
26
+	binary.BigEndian.PutUint16(buf[3:5], uint16(len(payload)))
27
+	copy(buf[5:], payload)
28
+
29
+	return buf
30
+}

+ 48
- 0
mtglib/internal/tls/utils.go Просмотреть файл

1
+package tls
2
+
3
+import (
4
+	"bytes"
5
+	"encoding/binary"
6
+	"fmt"
7
+	"io"
8
+)
9
+
10
+func ReadRecord(r io.Reader, w io.Writer) (byte, int64, error) {
11
+	buf := [SizeHeader]byte{}
12
+
13
+	if _, err := io.ReadFull(r, buf[:]); err != nil {
14
+		return 0, 0, err
15
+	}
16
+
17
+	pVer := buf[SizeRecordType:]
18
+	pLen := pVer[SizeVersion:]
19
+
20
+	if !bytes.Equal(TLSVersion[:], pVer[:SizeVersion]) {
21
+		return 0, 0, fmt.Errorf("incorrect tls version %v", pVer)
22
+	}
23
+
24
+	length := int64(binary.BigEndian.Uint16(pLen[:SizeSize]))
25
+	_, err := io.CopyN(w, r, length)
26
+
27
+	return buf[0], length, err
28
+}
29
+
30
+func WriteRecord(w io.Writer, payload []byte) error {
31
+	buf := [MaxRecordSize]byte{}
32
+	buf[0] = TypeApplicationData
33
+
34
+	bufV := buf[SizeRecordType:]
35
+	copy(bufV[:SizeVersion], TLSVersion[:])
36
+
37
+	bufS := bufV[SizeVersion:]
38
+	binary.BigEndian.PutUint16(bufS[:SizeSize], uint16(len(payload)))
39
+
40
+	bufP := buf[SizeHeader:]
41
+	if n := copy(bufP, payload); n != len(payload) {
42
+		return fmt.Errorf("copied %d bytes of payload instead of %d", n, len(payload))
43
+	}
44
+
45
+	_, err := w.Write(buf[:SizeHeader+len(payload)])
46
+
47
+	return err
48
+}

+ 125
- 0
mtglib/internal/tls/utils_test.go Просмотреть файл

1
+package tls
2
+
3
+import (
4
+	"bytes"
5
+	"encoding/binary"
6
+	"errors"
7
+	"testing"
8
+
9
+	"github.com/stretchr/testify/mock"
10
+	"github.com/stretchr/testify/suite"
11
+)
12
+
13
+type UtilsTestSuite struct {
14
+	suite.Suite
15
+
16
+	dst *bytes.Buffer
17
+}
18
+
19
+func (suite *UtilsTestSuite) SetupTest() {
20
+	suite.dst = &bytes.Buffer{}
21
+}
22
+
23
+func (suite *UtilsTestSuite) TestReadRecord() {
24
+	payload := []byte("hello world")
25
+	raw := MakeTLSRecord(0x17, payload)
26
+
27
+	recordType, length, err := ReadRecord(bytes.NewReader(raw), suite.dst)
28
+
29
+	suite.NoError(err)
30
+	suite.Equal(byte(0x17), recordType)
31
+	suite.Equal(int64(len(payload)), length)
32
+	suite.Equal(payload, suite.dst.Bytes())
33
+}
34
+
35
+func (suite *UtilsTestSuite) TestReadRecordChangeCipherSpec() {
36
+	payload := []byte{1}
37
+	raw := MakeTLSRecord(0x14, payload)
38
+
39
+	recordType, length, err := ReadRecord(bytes.NewReader(raw), suite.dst)
40
+
41
+	suite.NoError(err)
42
+	suite.Equal(byte(0x14), recordType)
43
+	suite.Equal(int64(1), length)
44
+}
45
+
46
+func (suite *UtilsTestSuite) TestReadRecordRejectsWrongVersion() {
47
+	record := []byte{0x17, 3, 1, 0, 5, 0, 0, 0, 0, 0}
48
+
49
+	_, _, err := ReadRecord(bytes.NewReader(record), suite.dst)
50
+	suite.ErrorContains(err, "incorrect tls version")
51
+}
52
+
53
+func (suite *UtilsTestSuite) TestReadRecordEmptyReader() {
54
+	_, _, err := ReadRecord(bytes.NewReader(nil), suite.dst)
55
+	suite.Error(err)
56
+}
57
+
58
+func (suite *UtilsTestSuite) TestReadRecordTruncatedHeader() {
59
+	_, _, err := ReadRecord(bytes.NewReader([]byte{0x17, 3}), suite.dst)
60
+	suite.Error(err)
61
+}
62
+
63
+func (suite *UtilsTestSuite) TestReadRecordTruncatedPayload() {
64
+	raw := MakeTLSRecord(0x17, []byte("full payload"))
65
+	truncated := raw[:5+3]
66
+
67
+	_, _, err := ReadRecord(bytes.NewReader(truncated), suite.dst)
68
+	suite.Error(err)
69
+}
70
+
71
+func (suite *UtilsTestSuite) TestWriteRecord() {
72
+	payload := []byte("hello world")
73
+
74
+	err := WriteRecord(suite.dst, payload)
75
+	suite.NoError(err)
76
+
77
+	written := suite.dst.Bytes()
78
+	suite.Equal(byte(0x17), written[0])
79
+	suite.Equal([]byte{3, 3}, written[1:3])
80
+
81
+	length := binary.BigEndian.Uint16(written[3:5])
82
+	suite.Equal(uint16(len(payload)), length)
83
+	suite.Equal(payload, written[5:])
84
+}
85
+
86
+func (suite *UtilsTestSuite) TestWriteRecordRoundTrip() {
87
+	payload := []byte("round trip test")
88
+
89
+	var wire bytes.Buffer
90
+
91
+	err := WriteRecord(&wire, payload)
92
+	suite.NoError(err)
93
+
94
+	var recovered bytes.Buffer
95
+
96
+	recordType, length, err := ReadRecord(&wire, &recovered)
97
+
98
+	suite.NoError(err)
99
+	suite.Equal(byte(0x17), recordType)
100
+	suite.Equal(int64(len(payload)), length)
101
+	suite.Equal(payload, recovered.Bytes())
102
+}
103
+
104
+func (suite *UtilsTestSuite) TestWriteRecordPropagatesError() {
105
+	m := &WriterMock{}
106
+	m.
107
+		On("Write", mock.AnythingOfType("[]uint8")).
108
+		Once().
109
+		Return(0, errors.New("dist full"))
110
+
111
+	err := WriteRecord(m, []byte("data"))
112
+	suite.Error(err)
113
+
114
+	m.AssertExpectations(suite.T())
115
+}
116
+
117
+func (suite *UtilsTestSuite) TestWriteRecordPayloadTooLarge() {
118
+	err := WriteRecord(suite.dst, make([]byte, MaxRecordPayloadSize+1))
119
+	suite.Error(err)
120
+}
121
+
122
+func TestUtils(t *testing.T) {
123
+	t.Parallel()
124
+	suite.Run(t, &UtilsTestSuite{})
125
+}

Загрузка…
Отмена
Сохранить