Browse Source

Add load balancing network dialer

tags/v2.0.0-rc1
9seconds 5 years ago
parent
commit
7002e4cd09

+ 195
- 0
mtglib/network/circuit_breaker.go View File

@@ -0,0 +1,195 @@
1
+package network
2
+
3
+import (
4
+	"context"
5
+	"net"
6
+	"sync/atomic"
7
+	"time"
8
+)
9
+
10
+const (
11
+	circuitBreakerStateClosed uint32 = iota
12
+	circuitBreakerStateHalfOpened
13
+	circuitBreakerStateOpened
14
+)
15
+
16
+type circuitBreakerDialer struct {
17
+	Dialer
18
+
19
+	stateMutexChan chan bool
20
+
21
+	halfOpenTimer        *time.Timer
22
+	failuresCleanupTimer *time.Timer
23
+
24
+	state            uint32
25
+	halfOpenAttempts uint32
26
+	failuresCount    uint32
27
+
28
+	openThreshold        uint32
29
+	halfOpenTimeout      time.Duration
30
+	resetFailuresTimeout time.Duration
31
+}
32
+
33
+func (c *circuitBreakerDialer) Dial(network, address string) (net.Conn, error) {
34
+	return c.DialContext(context.Background(), network, address)
35
+}
36
+
37
+func (c *circuitBreakerDialer) DialContext(ctx context.Context,
38
+	network, address string) (net.Conn, error) {
39
+	switch atomic.LoadUint32(&c.state) {
40
+	case circuitBreakerStateClosed:
41
+		return c.doClosed(ctx, network, address)
42
+	case circuitBreakerStateHalfOpened:
43
+		return c.doHalfOpened(ctx, network, address)
44
+	default:
45
+		return nil, ErrCircuitBreakerOpened
46
+	}
47
+}
48
+
49
+func (c *circuitBreakerDialer) doClosed(ctx context.Context,
50
+	network, address string) (net.Conn, error) {
51
+	conn, err := c.Dialer.DialContext(ctx, network, address)
52
+
53
+	select {
54
+	case <-ctx.Done():
55
+		if conn != nil {
56
+			conn.Close()
57
+		}
58
+
59
+		return nil, ctx.Err()
60
+	case c.stateMutexChan <- true:
61
+		defer func() {
62
+			<-c.stateMutexChan
63
+		}()
64
+	}
65
+
66
+	if err == nil {
67
+		c.switchState(circuitBreakerStateClosed)
68
+
69
+		return conn, err
70
+	}
71
+
72
+	c.failuresCount++
73
+
74
+	if c.state == circuitBreakerStateClosed && c.failuresCount > c.openThreshold {
75
+		c.switchState(circuitBreakerStateOpened)
76
+	}
77
+
78
+	return conn, err
79
+}
80
+
81
+func (c *circuitBreakerDialer) doHalfOpened(ctx context.Context, network, address string) (net.Conn, error) {
82
+	if !atomic.CompareAndSwapUint32(&c.halfOpenAttempts, 0, 1) {
83
+		return nil, ErrCircuitBreakerOpened
84
+	}
85
+
86
+	conn, err := c.Dialer.DialContext(ctx, network, address)
87
+
88
+	select {
89
+	case <-ctx.Done():
90
+		if conn != nil {
91
+			conn.Close()
92
+		}
93
+
94
+		return nil, ctx.Err()
95
+	case c.stateMutexChan <- true:
96
+		defer func() {
97
+			<-c.stateMutexChan
98
+		}()
99
+	}
100
+
101
+	if c.state != circuitBreakerStateHalfOpened {
102
+		return conn, err
103
+	}
104
+
105
+	if err == nil {
106
+		c.switchState(circuitBreakerStateClosed)
107
+	} else {
108
+		c.switchState(circuitBreakerStateOpened)
109
+	}
110
+
111
+	return conn, err
112
+}
113
+
114
+func (c *circuitBreakerDialer) switchState(state uint32) {
115
+	switch state {
116
+	case circuitBreakerStateClosed:
117
+		c.stopTimer(&c.halfOpenTimer)
118
+		c.ensureTimer(&c.failuresCleanupTimer, c.resetFailuresTimeout, c.resetFailures)
119
+	case circuitBreakerStateHalfOpened:
120
+		c.stopTimer(&c.failuresCleanupTimer)
121
+		c.stopTimer(&c.halfOpenTimer)
122
+	case circuitBreakerStateOpened:
123
+		c.stopTimer(&c.failuresCleanupTimer)
124
+		c.ensureTimer(&c.halfOpenTimer, c.halfOpenTimeout, c.tryHalfOpen)
125
+	}
126
+
127
+	c.failuresCount = 0
128
+
129
+	atomic.StoreUint32(&c.halfOpenAttempts, 0)
130
+	atomic.StoreUint32(&c.state, state)
131
+}
132
+
133
+func (c *circuitBreakerDialer) resetFailures() {
134
+	c.stateMutexChan <- true
135
+
136
+	defer func() {
137
+		<-c.stateMutexChan
138
+	}()
139
+
140
+	c.stopTimer(&c.failuresCleanupTimer)
141
+
142
+	if c.state == circuitBreakerStateClosed {
143
+		c.switchState(circuitBreakerStateClosed)
144
+	}
145
+}
146
+
147
+func (c *circuitBreakerDialer) tryHalfOpen() {
148
+	c.stateMutexChan <- true
149
+
150
+	defer func() {
151
+		<-c.stateMutexChan
152
+	}()
153
+
154
+	if c.state == circuitBreakerStateOpened {
155
+		c.switchState(circuitBreakerStateHalfOpened)
156
+	}
157
+}
158
+
159
+func (c *circuitBreakerDialer) stopTimer(timerRef **time.Timer) {
160
+	timer := *timerRef
161
+
162
+	if timer == nil {
163
+		return
164
+	}
165
+
166
+	timer.Stop()
167
+
168
+	select {
169
+	case <-timer.C:
170
+	default:
171
+	}
172
+
173
+	*timerRef = nil
174
+}
175
+
176
+func (c *circuitBreakerDialer) ensureTimer(timerRef **time.Timer,
177
+	timeout time.Duration, callback func()) {
178
+	if *timerRef == nil {
179
+		*timerRef = time.AfterFunc(timeout, callback)
180
+	}
181
+}
182
+
183
+func newCircuitBreakerDialer(baseDialer Dialer,
184
+	openThreshold uint32, halfOpenTimeout, resetFailuresTimeout time.Duration) Dialer {
185
+	cb := &circuitBreakerDialer{
186
+		Dialer:               baseDialer,
187
+		openThreshold:        openThreshold,
188
+		halfOpenTimeout:      halfOpenTimeout,
189
+		resetFailuresTimeout: resetFailuresTimeout,
190
+	}
191
+
192
+	cb.switchState(circuitBreakerStateClosed)
193
+
194
+	return cb
195
+}

+ 0
- 9
mtglib/network/consts.go View File

@@ -1,9 +0,0 @@
1
-package network
2
-
3
-import "time"
4
-
5
-const (
6
-	DefaultTimeout     = 10 * time.Second
7
-	DefaultHTTPTimeout = DefaultTimeout
8
-	DefaultBufferSize  = 4096
9
-)

+ 25
- 0
mtglib/network/init.go View File

@@ -0,0 +1,25 @@
1
+package network
2
+
3
+import (
4
+	"context"
5
+	"errors"
6
+	"net"
7
+	"time"
8
+)
9
+
10
+const (
11
+	DefaultTimeout     = 10 * time.Second
12
+	DefaultDNSTimeout  = time.Second
13
+	DefaultHTTPTimeout = DefaultTimeout
14
+	DefaultBufferSize  = 4096
15
+)
16
+
17
+var (
18
+	ErrCircuitBreakerOpened     = errors.New("circuit breaker is opened")
19
+	ErrCannotDialWithAllProxies = errors.New("cannot dial with all proxies")
20
+)
21
+
22
+type Dialer interface {
23
+	Dial(network, address string) (net.Conn, error)
24
+	DialContext(ctx context.Context, network, address string) (net.Conn, error)
25
+}

+ 19
- 0
mtglib/network/init_test.go View File

@@ -1,13 +1,32 @@
1 1
 package network_test
2 2
 
3 3
 import (
4
+	"context"
5
+	"net"
4 6
 	"net/http/httptest"
5 7
 	"strings"
6 8
 
7 9
 	"github.com/mccutchen/go-httpbin/httpbin"
10
+	"github.com/stretchr/testify/mock"
8 11
 	"github.com/stretchr/testify/suite"
9 12
 )
10 13
 
14
+type DialerMock struct {
15
+	mock.Mock
16
+}
17
+
18
+func (d *DialerMock) Dial(network, address string) (net.Conn, error) {
19
+	args := d.Called(network, address)
20
+
21
+	return args.Get(0).(net.Conn), args.Error(1)
22
+}
23
+
24
+func (d *DialerMock) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
25
+	args := d.Called(ctx, network, address)
26
+
27
+	return args.Get(0).(net.Conn), args.Error(1)
28
+}
29
+
11 30
 type HTTPServerTestSuite struct {
12 31
 	suite.Suite
13 32
 

+ 0
- 11
mtglib/network/interfaces.go View File

@@ -1,11 +0,0 @@
1
-package network
2
-
3
-import (
4
-	"context"
5
-	"net"
6
-)
7
-
8
-type Dialer interface {
9
-	Dial(network, address string) (net.Conn, error)
10
-	DialContext(ctx context.Context, network, address string) (net.Conn, error)
11
-}

+ 55
- 0
mtglib/network/load_balanced.go View File

@@ -0,0 +1,55 @@
1
+package network
2
+
3
+import (
4
+	"context"
5
+	"math/rand"
6
+	"net"
7
+	"net/url"
8
+)
9
+
10
+type loadBalancedDialer struct {
11
+	dialers []Dialer
12
+}
13
+
14
+func (l loadBalancedDialer) Dial(network, address string) (net.Conn, error) {
15
+	return l.DialContext(context.Background(), network, address)
16
+}
17
+
18
+func (l loadBalancedDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
19
+	length := len(l.dialers)
20
+	start := rand.Intn(length)
21
+	moved := false
22
+
23
+	for i := start; i != start || !moved; i = (i + 1) % length {
24
+		moved = true
25
+		if conn, err := l.dialers[i].DialContext(ctx, network, address); err == nil {
26
+			return conn, nil
27
+		}
28
+	}
29
+
30
+	return nil, ErrCannotDialWithAllProxies
31
+}
32
+
33
+func NewLoadBalancedDialer(baseDialer Dialer, proxyURLs []*url.URL) (Dialer, error) {
34
+	switch len(proxyURLs) {
35
+	case 0:
36
+		return baseDialer, nil
37
+	case 1:
38
+		return NewSocks5Dialer(baseDialer, proxyURLs[0])
39
+	}
40
+
41
+	dialers := []Dialer{}
42
+
43
+	for _, u := range proxyURLs {
44
+		dialer, err := NewSocks5Dialer(newProxyDialer(baseDialer, u), u)
45
+		if err != nil {
46
+			return nil, err
47
+		}
48
+
49
+		dialers = append(dialers, dialer)
50
+	}
51
+
52
+	return loadBalancedDialer{
53
+		dialers: dialers,
54
+	}, nil
55
+}

+ 1
- 1
mtglib/network/network.go View File

@@ -113,7 +113,7 @@ func NewNetwork(dialer Dialer, dohHostname string, httpTimeout time.Duration) (*
113 113
 	}
114 114
 
115 115
 	dohHTTPClient := &http.Client{
116
-		Timeout: httpTimeout,
116
+		Timeout: DefaultDNSTimeout,
117 117
 		Transport: &http.Transport{
118 118
 			DialContext: dialer.DialContext,
119 119
 		},

+ 43
- 0
mtglib/network/proxy_dialer.go View File

@@ -0,0 +1,43 @@
1
+package network
2
+
3
+import (
4
+	"net/url"
5
+	"strconv"
6
+	"time"
7
+)
8
+
9
+const (
10
+	ProxyDialerOpenThreshold        = 5
11
+	ProxyDialerHalfOpenTimeout      = time.Minute
12
+	ProxyDialerResetFailuresTimeout = 10 * time.Second
13
+)
14
+
15
+func newProxyDialer(baseDialer Dialer, proxyURL *url.URL) Dialer {
16
+	params := proxyURL.Query()
17
+
18
+	var (
19
+		openThreshold        uint32 = ProxyDialerOpenThreshold
20
+		halfOpenTimeout             = ProxyDialerHalfOpenTimeout
21
+		resetFailuresTimeout        = ProxyDialerResetFailuresTimeout
22
+	)
23
+
24
+	if param := params.Get("open_threshold"); param != "" {
25
+		if intNum, err := strconv.ParseUint(param, 10, 32); err == nil {
26
+			openThreshold = uint32(intNum)
27
+		}
28
+	}
29
+
30
+	if param := params.Get("half_open_timeout"); param != "" {
31
+		if dur, err := time.ParseDuration(param); err == nil && dur > 0 {
32
+			halfOpenTimeout = dur
33
+		}
34
+	}
35
+
36
+	if param := params.Get("reset_failures_timeout"); param != "" {
37
+		if dur, err := time.ParseDuration(param); err == nil && dur > 0 {
38
+			resetFailuresTimeout = dur
39
+		}
40
+	}
41
+
42
+	return newCircuitBreakerDialer(baseDialer, openThreshold, halfOpenTimeout, resetFailuresTimeout)
43
+}

+ 2
- 2
mtglib/network/socks5.go View File

@@ -7,8 +7,8 @@ import (
7 7
 	"golang.org/x/net/proxy"
8 8
 )
9 9
 
10
-func NewSocks5Dialer(proxyURL *url.URL, base Dialer) (Dialer, error) {
11
-	rv, err := proxy.FromURL(proxyURL, base)
10
+func NewSocks5Dialer(baseDialer Dialer, proxyURL *url.URL) (Dialer, error) {
11
+	rv, err := proxy.FromURL(proxyURL, baseDialer)
12 12
 	if err != nil {
13 13
 		return nil, fmt.Errorf("cannot initialize socks5 proxy dialer: %w", err)
14 14
 	}

+ 2
- 2
mtglib/network/socks5_test.go View File

@@ -47,7 +47,7 @@ func (suite *Socks5TestSuite) TestRequestFailed() {
47 47
 		User:   url.UserPassword("user2", "password"),
48 48
 		Host:   suite.socksListener.Addr().String(),
49 49
 	}
50
-	dialer, _ := network.NewSocks5Dialer(proxyURL, suite.baseDialer)
50
+	dialer, _ := network.NewSocks5Dialer(suite.baseDialer, proxyURL)
51 51
 
52 52
 	httpClient := http.Client{
53 53
 		Transport: &http.Transport{
@@ -66,7 +66,7 @@ func (suite *Socks5TestSuite) TestRequestOk() {
66 66
 		User:   url.UserPassword("user", "password"),
67 67
 		Host:   suite.socksListener.Addr().String(),
68 68
 	}
69
-	dialer, _ := network.NewSocks5Dialer(proxyURL, suite.baseDialer)
69
+	dialer, _ := network.NewSocks5Dialer(suite.baseDialer, proxyURL)
70 70
 
71 71
 	httpClient := http.Client{
72 72
 		Transport: &http.Transport{

Loading…
Cancel
Save