瀏覽代碼

More reasonable shutdowns

tags/1.0^2
9seconds 6 年之前
父節點
當前提交
701f62ed4c
共有 8 個檔案被更改,包括 115 行新增41 行删除
  1. 11
    3
      cli/proxy.go
  2. 6
    5
      config/config.go
  3. 11
    6
      config/global_ips.go
  4. 23
    23
      main.go
  5. 11
    3
      proxy/proxy.go
  6. 6
    1
      stats/stats.go
  7. 24
    0
      utils/signal_context.go
  8. 23
    0
      utils/signal_context_windows.go

+ 11
- 3
cli/proxy.go 查看文件

@@ -15,9 +15,12 @@ import (
15 15
 	"github.com/9seconds/mtg/proxy"
16 16
 	"github.com/9seconds/mtg/stats"
17 17
 	"github.com/9seconds/mtg/telegram"
18
+	"github.com/9seconds/mtg/utils"
18 19
 )
19 20
 
20 21
 func Proxy() error {
22
+	ctx := utils.GetSignalContext()
23
+
21 24
 	atom := zap.NewAtomicLevel()
22 25
 	switch {
23 26
 	case config.C.Debug:
@@ -37,7 +40,7 @@ func Proxy() error {
37 40
 	zap.ReplaceGlobals(logger)
38 41
 	defer logger.Sync() // nolint: errcheck
39 42
 
40
-	if err := config.InitPublicAddress(); err != nil {
43
+	if err := config.InitPublicAddress(ctx); err != nil {
41 44
 		Fatal(err.Error())
42 45
 	}
43 46
 	zap.S().Debugw("Configuration", "config", config.C)
@@ -61,16 +64,21 @@ func Proxy() error {
61 64
 	if err := antireplay.Init(); err != nil {
62 65
 		Fatal(err.Error())
63 66
 	}
64
-	if err := stats.Init(); err != nil {
67
+	if err := stats.Init(ctx); err != nil {
65 68
 		Fatal(err.Error())
66 69
 	}
67 70
 	proxyListener, err := net.Listen("tcp", config.C.ListenAddr.String())
68 71
 	if err != nil {
69 72
 		Fatal(err.Error())
70 73
 	}
74
+	go func() {
75
+		<-ctx.Done()
76
+		proxyListener.Close()
77
+	}()
71 78
 
72 79
 	app := &proxy.Proxy{
73
-		Logger: zap.S().Named("proxy"),
80
+		Logger:  zap.S().Named("proxy"),
81
+		Context: ctx,
74 82
 	}
75 83
 	if len(config.C.AdTag) == 0 {
76 84
 		app.TelegramProtocolMaker = obfuscated2.MakeTelegramProtocol

+ 6
- 5
config/config.go 查看文件

@@ -2,6 +2,7 @@ package config
2 2
 
3 3
 import (
4 4
 	"bytes"
5
+	"context"
5 6
 	"encoding/json"
6 7
 	"net"
7 8
 	"strconv"
@@ -140,14 +141,14 @@ func (c Config) String() string {
140 141
 	return string(data)
141 142
 }
142 143
 
143
-type ConfigOpt struct {
144
+type Opt struct {
144 145
 	Option OptionType
145 146
 	Value  interface{}
146 147
 }
147 148
 
148 149
 var C = Config{}
149 150
 
150
-func Init(options ...ConfigOpt) error { // nolint: gocyclo
151
+func Init(options ...Opt) error { // nolint: gocyclo
151 152
 	for _, opt := range options {
152 153
 		switch opt.Option {
153 154
 		case OptionTypeDebug:
@@ -222,7 +223,7 @@ func Init(options ...ConfigOpt) error { // nolint: gocyclo
222 223
 	return nil
223 224
 }
224 225
 
225
-func InitPublicAddress() error {
226
+func InitPublicAddress(ctx context.Context) error {
226 227
 	if C.PublicIPv4Addr.Port == 0 {
227 228
 		C.PublicIPv4Addr.Port = C.ListenAddr.Port
228 229
 	}
@@ -232,7 +233,7 @@ func InitPublicAddress() error {
232 233
 
233 234
 	foundAddress := C.PublicIPv4Addr.IP != nil || C.PublicIPv6Addr.IP != nil
234 235
 	if C.PublicIPv4Addr.IP == nil {
235
-		ip, err := getGlobalIPv4()
236
+		ip, err := getGlobalIPv4(ctx)
236 237
 		if err != nil {
237 238
 			zap.S().Warnw("Cannot resolve public address", "error", err)
238 239
 		} else {
@@ -241,7 +242,7 @@ func InitPublicAddress() error {
241 242
 		}
242 243
 	}
243 244
 	if C.PublicIPv6Addr.IP == nil {
244
-		ip, err := getGlobalIPv6()
245
+		ip, err := getGlobalIPv6(ctx)
245 246
 		if err != nil {
246 247
 			zap.S().Warnw("Cannot resolve public address", "error", err)
247 248
 		} else {

+ 11
- 6
config/global_ips.go 查看文件

@@ -17,23 +17,23 @@ const (
17 17
 	ifconfigTimeout = 10 * time.Second
18 18
 )
19 19
 
20
-func getGlobalIPv4() (net.IP, error) {
21
-	ip, err := fetchIP("tcp4")
20
+func getGlobalIPv4(ctx context.Context) (net.IP, error) {
21
+	ip, err := fetchIP(ctx, "tcp4")
22 22
 	if err != nil || ip.To4() == nil {
23 23
 		return nil, errors.Annotate(err, "Cannot find public ipv4 address")
24 24
 	}
25 25
 	return ip, nil
26 26
 }
27 27
 
28
-func getGlobalIPv6() (net.IP, error) {
29
-	ip, err := fetchIP("tcp6")
28
+func getGlobalIPv6(ctx context.Context) (net.IP, error) {
29
+	ip, err := fetchIP(ctx, "tcp6")
30 30
 	if err != nil || ip.To4() != nil {
31 31
 		return nil, errors.Annotate(err, "Cannot find public ipv6 address")
32 32
 	}
33 33
 	return ip, nil
34 34
 }
35 35
 
36
-func fetchIP(network string) (net.IP, error) {
36
+func fetchIP(ctx context.Context, network string) (net.IP, error) {
37 37
 	dialer := &net.Dialer{FallbackDelay: -1}
38 38
 	client := &http.Client{
39 39
 		Jar:     nil,
@@ -45,7 +45,12 @@ func fetchIP(network string) (net.IP, error) {
45 45
 		},
46 46
 	}
47 47
 
48
-	resp, err := client.Get(ifconfigAddress)
48
+	req, err := http.NewRequest("GET", ifconfigAddress, nil)
49
+	if err != nil {
50
+		return nil, errors.Annotate(err, "Cannot create a request")
51
+	}
52
+
53
+	resp, err := client.Do(req.WithContext(ctx))
49 54
 	if err != nil {
50 55
 		if resp != nil {
51 56
 			io.Copy(ioutil.Discard, resp.Body) // nolint: errcheck

+ 23
- 23
main.go 查看文件

@@ -152,29 +152,29 @@ func main() {
152 152
 
153 153
 	case proxyCommand.FullCommand():
154 154
 		err := config.Init(
155
-			config.ConfigOpt{Option: config.OptionTypeDebug, Value: *proxyDebug},
156
-			config.ConfigOpt{Option: config.OptionTypeVerbose, Value: *proxyVerbose},
157
-			config.ConfigOpt{Option: config.OptionTypeBindIP, Value: *proxyBindIP},
158
-			config.ConfigOpt{Option: config.OptionTypeBindPort, Value: *proxyBindPort},
159
-			config.ConfigOpt{Option: config.OptionTypePublicIPv4, Value: *proxyPublicIPv4},
160
-			config.ConfigOpt{Option: config.OptionTypePublicIPv4Port, Value: *proxyPublicIPv4Port},
161
-			config.ConfigOpt{Option: config.OptionTypePublicIPv6, Value: *proxyPublicIPv6},
162
-			config.ConfigOpt{Option: config.OptionTypePublicIPv6Port, Value: *proxyPublicIPv6Port},
163
-			config.ConfigOpt{Option: config.OptionTypeStatsIP, Value: *proxyStatsIP},
164
-			config.ConfigOpt{Option: config.OptionTypeStatsPort, Value: *proxyStatsPort},
165
-			config.ConfigOpt{Option: config.OptionTypeStatsdIP, Value: *proxyStatsdIP},
166
-			config.ConfigOpt{Option: config.OptionTypeStatsdPort, Value: *proxyStatsdPort},
167
-			config.ConfigOpt{Option: config.OptionTypeStatsdNetwork, Value: *proxyStatsdNetwork},
168
-			config.ConfigOpt{Option: config.OptionTypeStatsdPrefix, Value: *proxyStatsdPrefix},
169
-			config.ConfigOpt{Option: config.OptionTypeStatsdTagsFormat, Value: *proxyStatsdTagsFormat},
170
-			config.ConfigOpt{Option: config.OptionTypeStatsdTags, Value: *proxyStatsdTags},
171
-			config.ConfigOpt{Option: config.OptionTypePrometheusPrefix, Value: *proxyPrometheusPrefix},
172
-			config.ConfigOpt{Option: config.OptionTypeWriteBufferSize, Value: *proxyWriteBufferSize},
173
-			config.ConfigOpt{Option: config.OptionTypeReadBufferSize, Value: *proxyReadBufferSize},
174
-			config.ConfigOpt{Option: config.OptionTypeAntiReplayMaxSize, Value: *proxyAntiReplayMaxSize},
175
-			config.ConfigOpt{Option: config.OptionTypeAntiReplayEvictionTime, Value: *proxyAntiReplayEvictionTime},
176
-			config.ConfigOpt{Option: config.OptionTypeSecret, Value: *proxySecret},
177
-			config.ConfigOpt{Option: config.OptionTypeAdtag, Value: *proxyAdtag},
155
+			config.Opt{Option: config.OptionTypeDebug, Value: *proxyDebug},
156
+			config.Opt{Option: config.OptionTypeVerbose, Value: *proxyVerbose},
157
+			config.Opt{Option: config.OptionTypeBindIP, Value: *proxyBindIP},
158
+			config.Opt{Option: config.OptionTypeBindPort, Value: *proxyBindPort},
159
+			config.Opt{Option: config.OptionTypePublicIPv4, Value: *proxyPublicIPv4},
160
+			config.Opt{Option: config.OptionTypePublicIPv4Port, Value: *proxyPublicIPv4Port},
161
+			config.Opt{Option: config.OptionTypePublicIPv6, Value: *proxyPublicIPv6},
162
+			config.Opt{Option: config.OptionTypePublicIPv6Port, Value: *proxyPublicIPv6Port},
163
+			config.Opt{Option: config.OptionTypeStatsIP, Value: *proxyStatsIP},
164
+			config.Opt{Option: config.OptionTypeStatsPort, Value: *proxyStatsPort},
165
+			config.Opt{Option: config.OptionTypeStatsdIP, Value: *proxyStatsdIP},
166
+			config.Opt{Option: config.OptionTypeStatsdPort, Value: *proxyStatsdPort},
167
+			config.Opt{Option: config.OptionTypeStatsdNetwork, Value: *proxyStatsdNetwork},
168
+			config.Opt{Option: config.OptionTypeStatsdPrefix, Value: *proxyStatsdPrefix},
169
+			config.Opt{Option: config.OptionTypeStatsdTagsFormat, Value: *proxyStatsdTagsFormat},
170
+			config.Opt{Option: config.OptionTypeStatsdTags, Value: *proxyStatsdTags},
171
+			config.Opt{Option: config.OptionTypePrometheusPrefix, Value: *proxyPrometheusPrefix},
172
+			config.Opt{Option: config.OptionTypeWriteBufferSize, Value: *proxyWriteBufferSize},
173
+			config.Opt{Option: config.OptionTypeReadBufferSize, Value: *proxyReadBufferSize},
174
+			config.Opt{Option: config.OptionTypeAntiReplayMaxSize, Value: *proxyAntiReplayMaxSize},
175
+			config.Opt{Option: config.OptionTypeAntiReplayEvictionTime, Value: *proxyAntiReplayEvictionTime},
176
+			config.Opt{Option: config.OptionTypeSecret, Value: *proxySecret},
177
+			config.Opt{Option: config.OptionTypeAdtag, Value: *proxyAdtag},
178 178
 		)
179 179
 		if err != nil {
180 180
 			cli.Fatal(err.Error())

+ 11
- 3
proxy/proxy.go 查看文件

@@ -20,17 +20,25 @@ const directPipeBufferSize = 1024 * 1024
20 20
 
21 21
 type Proxy struct {
22 22
 	Logger                *zap.SugaredLogger
23
+	Context               context.Context
23 24
 	ClientProtocolMaker   protocol.ClientProtocolMaker
24 25
 	TelegramProtocolMaker protocol.TelegramProtocolMaker
25 26
 	TelegramDialer        telegram.Telegram
26 27
 }
27 28
 
28 29
 func (p *Proxy) Serve(listener net.Listener) {
30
+	doneChan := p.Context.Done()
31
+
29 32
 	for {
30 33
 		conn, err := listener.Accept()
31 34
 		if err != nil {
32
-			p.Logger.Errorw("Cannot allocate incoming connection", "error", err)
33
-			continue
35
+			select {
36
+			case <-doneChan:
37
+				return
38
+			default:
39
+				p.Logger.Errorw("Cannot allocate incoming connection", "error", err)
40
+				continue
41
+			}
34 42
 		}
35 43
 		go p.accept(conn)
36 44
 	}
@@ -53,7 +61,7 @@ func (p *Proxy) accept(conn net.Conn) {
53 61
 		return
54 62
 	}
55 63
 
56
-	ctx, cancel := context.WithCancel(context.Background())
64
+	ctx, cancel := context.WithCancel(p.Context)
57 65
 	defer cancel()
58 66
 
59 67
 	wrappedConn := wrappers.NewClientConn(ctx, cancel, conn, connID)

+ 6
- 1
stats/stats.go 查看文件

@@ -1,6 +1,7 @@
1 1
 package stats
2 2
 
3 3
 import (
4
+	"context"
4 5
 	"net"
5 6
 	"net/http"
6 7
 
@@ -59,7 +60,7 @@ func (m multiStats) AntiReplayDetected() {
59 60
 
60 61
 var S Stats
61 62
 
62
-func Init() error {
63
+func Init(ctx context.Context) error {
63 64
 	mux := http.NewServeMux()
64 65
 
65 66
 	instanceJSON := newStatsJSON(mux)
@@ -86,6 +87,10 @@ func Init() error {
86 87
 		Handler: mux,
87 88
 	}
88 89
 	go srv.Serve(listener) // nolint: errcheck
90
+	go func() {
91
+		<-ctx.Done()
92
+		srv.Shutdown(context.Background()) // nolint: errcheck
93
+	}()
89 94
 
90 95
 	S = multiStats(stats)
91 96
 

+ 24
- 0
utils/signal_context.go 查看文件

@@ -0,0 +1,24 @@
1
+// +build !windows
2
+
3
+package utils
4
+
5
+import (
6
+	"context"
7
+	"os"
8
+	"os/signal"
9
+	"syscall"
10
+)
11
+
12
+func GetSignalContext() context.Context {
13
+	ctx, cancel := context.WithCancel(context.Background())
14
+	sigChan := make(chan os.Signal, 1)
15
+
16
+	signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
17
+	go func() {
18
+		for range sigChan {
19
+			cancel()
20
+		}
21
+	}()
22
+
23
+	return ctx
24
+}

+ 23
- 0
utils/signal_context_windows.go 查看文件

@@ -0,0 +1,23 @@
1
+// +build windows
2
+
3
+package utils
4
+
5
+import (
6
+	"context"
7
+	"os"
8
+	"os/signal"
9
+)
10
+
11
+func GetSignalContext() context.Context {
12
+	ctx, cancel := context.WithCancel(context.Background())
13
+	sigChan := make(chan os.Signal, 1)
14
+
15
+	signal.Notify(sigChan, os.Interrupt)
16
+	go func() {
17
+		for range sigChan {
18
+			cancel()
19
+		}
20
+	}()
21
+
22
+	return ctx
23
+}

Loading…
取消
儲存