소스 검색

Add timeoutreadwritecloser

tags/0.9
9seconds 8 년 전
부모
커밋
83b46d1b80
4개의 변경된 파일76개의 추가작업 그리고 21개의 파일을 삭제
  1. 13
    1
      main.go
  2. 24
    17
      server/server.go
  3. 4
    3
      server/telegram.go
  4. 35
    0
      server/timeoutrwc.go

+ 13
- 1
main.go 파일 보기

@@ -39,6 +39,16 @@ var (
39 39
 			Envar("MTG_PORT").
40 40
 			Default("3128").
41 41
 			Uint16()
42
+	readTimeout = app.Flag("read-timeout", "Socket read timeout").
43
+			Short('r').
44
+			Envar("MTG_READ_TIMEOUT").
45
+			Default("30s").
46
+			Duration()
47
+	writeTimeout = app.Flag("write-timeout", "Socket write timeout").
48
+			Short('w').
49
+			Envar("MTG_WRITE_TIMEOUT").
50
+			Default("30s").
51
+			Duration()
42 52
 	serverName = app.Flag("server-name",
43 53
 		"Which server name to use. Default is IP address resolved by ipify.").
44 54
 		Short('s').
@@ -86,7 +96,9 @@ func main() {
86 96
 	)).Sugar()
87 97
 
88 98
 	printURLs()
89
-	if err := server.NewServer(*bindIP, int(*bindPort), secretBytes, logger).Serve(); err != nil {
99
+	srv := server.NewServer(*bindIP, int(*bindPort), secretBytes, logger,
100
+		*readTimeout, *writeTimeout)
101
+	if err := srv.Serve(); err != nil {
90 102
 		logger.Fatal(err.Error())
91 103
 	}
92 104
 }

+ 24
- 17
server/server.go 파일 보기

@@ -6,6 +6,7 @@ import (
6 6
 	"net"
7 7
 	"strconv"
8 8
 	"sync"
9
+	"time"
9 10
 
10 11
 	"github.com/9seconds/mtg/obfuscated2"
11 12
 	"github.com/juju/errors"
@@ -16,12 +17,14 @@ import (
16 17
 const bufferSize = 4096
17 18
 
18 19
 type Server struct {
19
-	ip     net.IP
20
-	port   int
21
-	secret []byte
22
-	logger *zap.SugaredLogger
23
-	lsock  net.Listener
24
-	ctx    context.Context
20
+	ip           net.IP
21
+	port         int
22
+	secret       []byte
23
+	logger       *zap.SugaredLogger
24
+	lsock        net.Listener
25
+	ctx          context.Context
26
+	readTimeout  time.Duration
27
+	writeTimeout time.Duration
25 28
 }
26 29
 
27 30
 func (s *Server) Serve() error {
@@ -98,7 +101,8 @@ func (s *Server) makeSocketID() string {
98 101
 }
99 102
 
100 103
 func (s *Server) getClientStream(conn net.Conn, ctx context.Context, cancel context.CancelFunc, socketID string) (io.ReadWriteCloser, int16, error) {
101
-	frame, err := obfuscated2.ExtractFrame(conn)
104
+	wConn := newTimeoutReadWriteCloser(conn, s.readTimeout, s.writeTimeout)
105
+	frame, err := obfuscated2.ExtractFrame(wConn)
102 106
 	if err != nil {
103 107
 		return nil, 0, errors.Annotate(err, "Cannot create client stream")
104 108
 	}
@@ -108,7 +112,7 @@ func (s *Server) getClientStream(conn net.Conn, ctx context.Context, cancel cont
108 112
 		return nil, 0, errors.Annotate(err, "Cannot create client stream")
109 113
 	}
110 114
 
111
-	wConn := newLogReadWriteCloser(conn, s.logger, socketID, "client")
115
+	wConn = newLogReadWriteCloser(wConn, s.logger, socketID, "client")
112 116
 	wConn = newCipherReadWriteCloser(conn, obfs2)
113 117
 	wConn = newCtxReadWriteCloser(wConn, ctx, cancel)
114 118
 
@@ -116,17 +120,18 @@ func (s *Server) getClientStream(conn net.Conn, ctx context.Context, cancel cont
116 120
 }
117 121
 
118 122
 func (s *Server) getTelegramStream(dc int16, ctx context.Context, cancel context.CancelFunc, socketID string) (io.ReadWriteCloser, error) {
119
-	socket, err := dialToTelegram(dc)
123
+	socket, err := dialToTelegram(dc, s.readTimeout)
120 124
 	if err != nil {
121 125
 		return nil, errors.Annotate(err, "Cannot dial")
122 126
 	}
127
+	wConn := newTimeoutReadWriteCloser(socket, s.readTimeout, s.writeTimeout)
123 128
 
124 129
 	obfs2, frame := obfuscated2.MakeTelegramObfuscated2Frame()
125 130
 	if n, err := socket.Write(frame); err != nil || n != len(frame) {
126 131
 		return nil, errors.Annotate(err, "Cannot write hadnshake frame")
127 132
 	}
128 133
 
129
-	wConn := newLogReadWriteCloser(socket, s.logger, socketID, "telegram")
134
+	wConn = newLogReadWriteCloser(socket, s.logger, socketID, "telegram")
130 135
 	wConn = newCipherReadWriteCloser(wConn, obfs2)
131 136
 	wConn = newCtxReadWriteCloser(wConn, ctx, cancel)
132 137
 
@@ -138,15 +143,17 @@ func (s *Server) pipe(wait *sync.WaitGroup, reader io.Reader, writer io.Writer)
138 143
 
139 144
 	buf := make([]byte, bufferSize)
140 145
 	io.CopyBuffer(writer, reader, buf)
141
-
142 146
 }
143 147
 
144
-func NewServer(ip net.IP, port int, secret []byte, logger *zap.SugaredLogger) *Server {
148
+func NewServer(ip net.IP, port int, secret []byte, logger *zap.SugaredLogger,
149
+	readTimeout, writeTimeout time.Duration) *Server {
145 150
 	return &Server{
146
-		ip:     ip,
147
-		port:   port,
148
-		secret: secret,
149
-		ctx:    context.Background(),
150
-		logger: logger,
151
+		ip:           ip,
152
+		port:         port,
153
+		secret:       secret,
154
+		ctx:          context.Background(),
155
+		logger:       logger,
156
+		readTimeout:  readTimeout,
157
+		writeTimeout: writeTimeout,
151 158
 	}
152 159
 }

+ 4
- 3
server/telegram.go 파일 보기

@@ -17,13 +17,14 @@ var telegramDCIPs = [5]string{
17 17
 
18 18
 const telegramKeepAlive = 30 * time.Second
19 19
 
20
-func dialToTelegram(dcIdx int16) (net.Conn, error) {
20
+func dialToTelegram(dcIdx int16, timeout time.Duration) (net.Conn, error) {
21 21
 	if dcIdx < 0 || dcIdx >= 5 {
22 22
 		return nil, errors.New("Incorrect DC IDX")
23 23
 	}
24 24
 
25
-	tcpAddr, _ := net.ResolveTCPAddr("tcp", telegramDCIPs[dcIdx])
26
-	conn, err := net.DialTCP("tcp", nil, tcpAddr)
25
+	dialer := net.Dialer{Timeout: timeout}
26
+	rawConn, err := dialer.Dial("tcp", telegramDCIPs[dcIdx])
27
+	conn := rawConn.(*net.TCPConn)
27 28
 	if err != nil {
28 29
 		return nil, errors.Annotate(err, "Cannot dial")
29 30
 	}

+ 35
- 0
server/timeoutrwc.go 파일 보기

@@ -0,0 +1,35 @@
1
+package server
2
+
3
+import (
4
+	"io"
5
+	"net"
6
+	"time"
7
+)
8
+
9
+type TimeoutReadWriteCloser struct {
10
+	conn         net.Conn
11
+	readTimeout  time.Duration
12
+	writeTimeout time.Duration
13
+}
14
+
15
+func (t *TimeoutReadWriteCloser) Read(p []byte) (int, error) {
16
+	t.conn.SetReadDeadline(time.Now().Add(t.readTimeout))
17
+	return t.conn.Read(p)
18
+}
19
+
20
+func (t *TimeoutReadWriteCloser) Write(p []byte) (int, error) {
21
+	t.conn.SetWriteDeadline(time.Now().Add(t.writeTimeout))
22
+	return t.conn.Write(p)
23
+}
24
+
25
+func (t *TimeoutReadWriteCloser) Close() error {
26
+	return t.conn.Close()
27
+}
28
+
29
+func newTimeoutReadWriteCloser(conn net.Conn, readTimeout, writeTimeout time.Duration) io.ReadWriteCloser {
30
+	return &TimeoutReadWriteCloser{
31
+		conn:         conn,
32
+		readTimeout:  readTimeout,
33
+		writeTimeout: writeTimeout,
34
+	}
35
+}

Loading…
취소
저장