|
|
@@ -0,0 +1,102 @@
|
|
|
1
|
+package wrappers
|
|
|
2
|
+
|
|
|
3
|
+import (
|
|
|
4
|
+ "bytes"
|
|
|
5
|
+ "io"
|
|
|
6
|
+ "io/ioutil"
|
|
|
7
|
+ "net"
|
|
|
8
|
+
|
|
|
9
|
+ "github.com/juju/errors"
|
|
|
10
|
+
|
|
|
11
|
+ "github.com/9seconds/mtg/mtproto"
|
|
|
12
|
+ "github.com/9seconds/mtg/mtproto/rpc"
|
|
|
13
|
+ "github.com/9seconds/mtg/wrappers"
|
|
|
14
|
+)
|
|
|
15
|
+
|
|
|
16
|
+var (
|
|
|
17
|
+ rpcCloseExtTag = [4]byte{0xa2, 0x34, 0xb6, 0x5e}
|
|
|
18
|
+ rpcProxyAnsTag = [4]byte{0x0d, 0xda, 0x03, 0x44}
|
|
|
19
|
+ rpcSimpleAckTag = [4]byte{0x9b, 0x40, 0xac, 0x3b}
|
|
|
20
|
+)
|
|
|
21
|
+
|
|
|
22
|
+type ProxyRequestReadWriteCloserWithAddr struct {
|
|
|
23
|
+ wrappers.BufferedReader
|
|
|
24
|
+
|
|
|
25
|
+ conn wrappers.ReadWriteCloserWithAddr
|
|
|
26
|
+ req *rpc.RPCProxyRequest
|
|
|
27
|
+}
|
|
|
28
|
+
|
|
|
29
|
+func (p *ProxyRequestReadWriteCloserWithAddr) Read(buf []byte) (int, error) {
|
|
|
30
|
+ return p.BufferedRead(buf, func() error {
|
|
|
31
|
+ ansBuf := &bytes.Buffer{}
|
|
|
32
|
+ ansBuf.Grow(4)
|
|
|
33
|
+
|
|
|
34
|
+ if _, err := io.CopyN(ansBuf, p.conn, 4); err != nil {
|
|
|
35
|
+ return errors.Annotate(err, "Cannot read RPC tag")
|
|
|
36
|
+ }
|
|
|
37
|
+
|
|
|
38
|
+ if bytes.Equal(ansBuf.Bytes(), rpcCloseExtTag[:]) {
|
|
|
39
|
+ return errors.New("Connection has been closed remotely")
|
|
|
40
|
+ } else if bytes.Equal(ansBuf.Bytes(), rpcProxyAnsTag[:]) {
|
|
|
41
|
+ if _, err := io.CopyN(ioutil.Discard, p.conn, 8+4); err != nil {
|
|
|
42
|
+ return errors.Annotate(err, "Cannot skip flags and connid")
|
|
|
43
|
+ }
|
|
|
44
|
+ for {
|
|
|
45
|
+ n, err := p.conn.Read(buf)
|
|
|
46
|
+ if err != nil {
|
|
|
47
|
+ return errors.Annotate(err, "Cannot read proxy answer")
|
|
|
48
|
+ }
|
|
|
49
|
+ if n == 0 {
|
|
|
50
|
+ break
|
|
|
51
|
+ }
|
|
|
52
|
+ p.Buffer.Write(buf[:n])
|
|
|
53
|
+ }
|
|
|
54
|
+ return nil
|
|
|
55
|
+ } else if bytes.Equal(ansBuf.Bytes(), rpcSimpleAckTag[:]) {
|
|
|
56
|
+ if _, err := io.CopyN(ioutil.Discard, p.conn, 8); err != nil {
|
|
|
57
|
+ return errors.Annotate(err, "Cannot skip connid")
|
|
|
58
|
+ }
|
|
|
59
|
+ if _, err := io.CopyN(p.Buffer, p.conn, 4); err != nil {
|
|
|
60
|
+ return errors.Annotate(err, "Cannot read simple ack")
|
|
|
61
|
+ }
|
|
|
62
|
+ p.req.Options.SimpleAck = true
|
|
|
63
|
+ return nil
|
|
|
64
|
+ }
|
|
|
65
|
+
|
|
|
66
|
+ return nil
|
|
|
67
|
+ })
|
|
|
68
|
+}
|
|
|
69
|
+
|
|
|
70
|
+func (p *ProxyRequestReadWriteCloserWithAddr) Write(raw []byte) (int, error) {
|
|
|
71
|
+ if _, err := p.conn.Write(p.req.Bytes(raw)); err != nil {
|
|
|
72
|
+ return 0, err
|
|
|
73
|
+ }
|
|
|
74
|
+ p.req.Options.SimpleAck = false
|
|
|
75
|
+ p.req.Options.QuickAck = false
|
|
|
76
|
+
|
|
|
77
|
+ return len(raw), nil
|
|
|
78
|
+}
|
|
|
79
|
+
|
|
|
80
|
+func (p *ProxyRequestReadWriteCloserWithAddr) Close() error {
|
|
|
81
|
+ return p.conn.Close()
|
|
|
82
|
+}
|
|
|
83
|
+
|
|
|
84
|
+func (p *ProxyRequestReadWriteCloserWithAddr) LocalAddr() *net.TCPAddr {
|
|
|
85
|
+ return p.conn.LocalAddr()
|
|
|
86
|
+}
|
|
|
87
|
+
|
|
|
88
|
+func (p *ProxyRequestReadWriteCloserWithAddr) RemoteAddr() *net.TCPAddr {
|
|
|
89
|
+ return p.conn.RemoteAddr()
|
|
|
90
|
+}
|
|
|
91
|
+
|
|
|
92
|
+func NewProxyRequestRWC(conn wrappers.ReadWriteCloserWithAddr, connOpts *mtproto.ConnectionOpts, adTag []byte) (wrappers.ReadWriteCloserWithAddr, error) {
|
|
|
93
|
+ req, err := rpc.NewRPCProxyRequest(connOpts.ClientAddr, conn.LocalAddr(), connOpts, adTag)
|
|
|
94
|
+ if err != nil {
|
|
|
95
|
+ return nil, errors.Annotate(err, "Cannot create new RPC proxy request")
|
|
|
96
|
+ }
|
|
|
97
|
+
|
|
|
98
|
+ return &ProxyRequestReadWriteCloserWithAddr{
|
|
|
99
|
+ conn: conn,
|
|
|
100
|
+ req: req,
|
|
|
101
|
+ }, nil
|
|
|
102
|
+}
|