Просмотр исходного кода

Add timeoutreadwritecloser

tags/0.9
9seconds 8 лет назад
Родитель
Сommit
83b46d1b80
4 измененных файлов: 76 добавлений и 21 удалений
  1. 13
    1
      main.go
  2. 24
    17
      server/server.go
  3. 4
    3
      server/telegram.go
  4. 35
    0
      server/timeoutrwc.go

+ 13
- 1
main.go Просмотреть файл

39
 			Envar("MTG_PORT").
39
 			Envar("MTG_PORT").
40
 			Default("3128").
40
 			Default("3128").
41
 			Uint16()
41
 			Uint16()
42
+	readTimeout = app.Flag("read-timeout", "Socket read timeout").
43
+			Short('r').
44
+			Envar("MTG_READ_TIMEOUT").
45
+			Default("30s").
46
+			Duration()
47
+	writeTimeout = app.Flag("write-timeout", "Socket write timeout").
48
+			Short('w').
49
+			Envar("MTG_WRITE_TIMEOUT").
50
+			Default("30s").
51
+			Duration()
42
 	serverName = app.Flag("server-name",
52
 	serverName = app.Flag("server-name",
43
 		"Which server name to use. Default is IP address resolved by ipify.").
53
 		"Which server name to use. Default is IP address resolved by ipify.").
44
 		Short('s').
54
 		Short('s').
86
 	)).Sugar()
96
 	)).Sugar()
87
 
97
 
88
 	printURLs()
98
 	printURLs()
89
-	if err := server.NewServer(*bindIP, int(*bindPort), secretBytes, logger).Serve(); err != nil {
99
+	srv := server.NewServer(*bindIP, int(*bindPort), secretBytes, logger,
100
+		*readTimeout, *writeTimeout)
101
+	if err := srv.Serve(); err != nil {
90
 		logger.Fatal(err.Error())
102
 		logger.Fatal(err.Error())
91
 	}
103
 	}
92
 }
104
 }

+ 24
- 17
server/server.go Просмотреть файл

6
 	"net"
6
 	"net"
7
 	"strconv"
7
 	"strconv"
8
 	"sync"
8
 	"sync"
9
+	"time"
9
 
10
 
10
 	"github.com/9seconds/mtg/obfuscated2"
11
 	"github.com/9seconds/mtg/obfuscated2"
11
 	"github.com/juju/errors"
12
 	"github.com/juju/errors"
16
 const bufferSize = 4096
17
 const bufferSize = 4096
17
 
18
 
18
 type Server struct {
19
 type Server struct {
19
-	ip     net.IP
20
-	port   int
21
-	secret []byte
22
-	logger *zap.SugaredLogger
23
-	lsock  net.Listener
24
-	ctx    context.Context
20
+	ip           net.IP
21
+	port         int
22
+	secret       []byte
23
+	logger       *zap.SugaredLogger
24
+	lsock        net.Listener
25
+	ctx          context.Context
26
+	readTimeout  time.Duration
27
+	writeTimeout time.Duration
25
 }
28
 }
26
 
29
 
27
 func (s *Server) Serve() error {
30
 func (s *Server) Serve() error {
98
 }
101
 }
99
 
102
 
100
 func (s *Server) getClientStream(conn net.Conn, ctx context.Context, cancel context.CancelFunc, socketID string) (io.ReadWriteCloser, int16, error) {
103
 func (s *Server) getClientStream(conn net.Conn, ctx context.Context, cancel context.CancelFunc, socketID string) (io.ReadWriteCloser, int16, error) {
101
-	frame, err := obfuscated2.ExtractFrame(conn)
104
+	wConn := newTimeoutReadWriteCloser(conn, s.readTimeout, s.writeTimeout)
105
+	frame, err := obfuscated2.ExtractFrame(wConn)
102
 	if err != nil {
106
 	if err != nil {
103
 		return nil, 0, errors.Annotate(err, "Cannot create client stream")
107
 		return nil, 0, errors.Annotate(err, "Cannot create client stream")
104
 	}
108
 	}
108
 		return nil, 0, errors.Annotate(err, "Cannot create client stream")
112
 		return nil, 0, errors.Annotate(err, "Cannot create client stream")
109
 	}
113
 	}
110
 
114
 
111
-	wConn := newLogReadWriteCloser(conn, s.logger, socketID, "client")
115
+	wConn = newLogReadWriteCloser(wConn, s.logger, socketID, "client")
112
 	wConn = newCipherReadWriteCloser(conn, obfs2)
116
 	wConn = newCipherReadWriteCloser(conn, obfs2)
113
 	wConn = newCtxReadWriteCloser(wConn, ctx, cancel)
117
 	wConn = newCtxReadWriteCloser(wConn, ctx, cancel)
114
 
118
 
116
 }
120
 }
117
 
121
 
118
 func (s *Server) getTelegramStream(dc int16, ctx context.Context, cancel context.CancelFunc, socketID string) (io.ReadWriteCloser, error) {
122
 func (s *Server) getTelegramStream(dc int16, ctx context.Context, cancel context.CancelFunc, socketID string) (io.ReadWriteCloser, error) {
119
-	socket, err := dialToTelegram(dc)
123
+	socket, err := dialToTelegram(dc, s.readTimeout)
120
 	if err != nil {
124
 	if err != nil {
121
 		return nil, errors.Annotate(err, "Cannot dial")
125
 		return nil, errors.Annotate(err, "Cannot dial")
122
 	}
126
 	}
127
+	wConn := newTimeoutReadWriteCloser(socket, s.readTimeout, s.writeTimeout)
123
 
128
 
124
 	obfs2, frame := obfuscated2.MakeTelegramObfuscated2Frame()
129
 	obfs2, frame := obfuscated2.MakeTelegramObfuscated2Frame()
125
 	if n, err := socket.Write(frame); err != nil || n != len(frame) {
130
 	if n, err := socket.Write(frame); err != nil || n != len(frame) {
126
 		return nil, errors.Annotate(err, "Cannot write hadnshake frame")
131
 		return nil, errors.Annotate(err, "Cannot write hadnshake frame")
127
 	}
132
 	}
128
 
133
 
129
-	wConn := newLogReadWriteCloser(socket, s.logger, socketID, "telegram")
134
+	wConn = newLogReadWriteCloser(socket, s.logger, socketID, "telegram")
130
 	wConn = newCipherReadWriteCloser(wConn, obfs2)
135
 	wConn = newCipherReadWriteCloser(wConn, obfs2)
131
 	wConn = newCtxReadWriteCloser(wConn, ctx, cancel)
136
 	wConn = newCtxReadWriteCloser(wConn, ctx, cancel)
132
 
137
 
138
 
143
 
139
 	buf := make([]byte, bufferSize)
144
 	buf := make([]byte, bufferSize)
140
 	io.CopyBuffer(writer, reader, buf)
145
 	io.CopyBuffer(writer, reader, buf)
141
-
142
 }
146
 }
143
 
147
 
144
-func NewServer(ip net.IP, port int, secret []byte, logger *zap.SugaredLogger) *Server {
148
+func NewServer(ip net.IP, port int, secret []byte, logger *zap.SugaredLogger,
149
+	readTimeout, writeTimeout time.Duration) *Server {
145
 	return &Server{
150
 	return &Server{
146
-		ip:     ip,
147
-		port:   port,
148
-		secret: secret,
149
-		ctx:    context.Background(),
150
-		logger: logger,
151
+		ip:           ip,
152
+		port:         port,
153
+		secret:       secret,
154
+		ctx:          context.Background(),
155
+		logger:       logger,
156
+		readTimeout:  readTimeout,
157
+		writeTimeout: writeTimeout,
151
 	}
158
 	}
152
 }
159
 }

+ 4
- 3
server/telegram.go Просмотреть файл

17
 
17
 
18
 const telegramKeepAlive = 30 * time.Second
18
 const telegramKeepAlive = 30 * time.Second
19
 
19
 
20
-func dialToTelegram(dcIdx int16) (net.Conn, error) {
20
+func dialToTelegram(dcIdx int16, timeout time.Duration) (net.Conn, error) {
21
 	if dcIdx < 0 || dcIdx >= 5 {
21
 	if dcIdx < 0 || dcIdx >= 5 {
22
 		return nil, errors.New("Incorrect DC IDX")
22
 		return nil, errors.New("Incorrect DC IDX")
23
 	}
23
 	}
24
 
24
 
25
-	tcpAddr, _ := net.ResolveTCPAddr("tcp", telegramDCIPs[dcIdx])
26
-	conn, err := net.DialTCP("tcp", nil, tcpAddr)
25
+	dialer := net.Dialer{Timeout: timeout}
26
+	rawConn, err := dialer.Dial("tcp", telegramDCIPs[dcIdx])
27
+	conn := rawConn.(*net.TCPConn)
27
 	if err != nil {
28
 	if err != nil {
28
 		return nil, errors.Annotate(err, "Cannot dial")
29
 		return nil, errors.Annotate(err, "Cannot dial")
29
 	}
30
 	}

+ 35
- 0
server/timeoutrwc.go Просмотреть файл

1
+package server
2
+
3
+import (
4
+	"io"
5
+	"net"
6
+	"time"
7
+)
8
+
9
+type TimeoutReadWriteCloser struct {
10
+	conn         net.Conn
11
+	readTimeout  time.Duration
12
+	writeTimeout time.Duration
13
+}
14
+
15
+func (t *TimeoutReadWriteCloser) Read(p []byte) (int, error) {
16
+	t.conn.SetReadDeadline(time.Now().Add(t.readTimeout))
17
+	return t.conn.Read(p)
18
+}
19
+
20
+func (t *TimeoutReadWriteCloser) Write(p []byte) (int, error) {
21
+	t.conn.SetWriteDeadline(time.Now().Add(t.writeTimeout))
22
+	return t.conn.Write(p)
23
+}
24
+
25
+func (t *TimeoutReadWriteCloser) Close() error {
26
+	return t.conn.Close()
27
+}
28
+
29
+func newTimeoutReadWriteCloser(conn net.Conn, readTimeout, writeTimeout time.Duration) io.ReadWriteCloser {
30
+	return &TimeoutReadWriteCloser{
31
+		conn:         conn,
32
+		readTimeout:  readTimeout,
33
+		writeTimeout: writeTimeout,
34
+	}
35
+}

Загрузка…
Отмена
Сохранить