|
|
@@ -4,7 +4,6 @@ import (
|
|
4
|
4
|
"bytes"
|
|
5
|
5
|
"crypto/aes"
|
|
6
|
6
|
"crypto/cipher"
|
|
7
|
|
- "errors"
|
|
8
|
7
|
"fmt"
|
|
9
|
8
|
"net"
|
|
10
|
9
|
"time"
|
|
|
@@ -12,10 +11,9 @@ import (
|
|
12
|
11
|
"go.uber.org/zap"
|
|
13
|
12
|
|
|
14
|
13
|
"github.com/9seconds/mtg/conntypes"
|
|
|
14
|
+ "github.com/9seconds/mtg/utils"
|
|
15
|
15
|
)
|
|
16
|
16
|
|
|
17
|
|
-const blockCipherReadCurrentDataBufferSize = 1024 + 1 // +1 because telegram operates with blocks mod 4
|
|
18
|
|
-
|
|
19
|
17
|
type wrapperBlockCipher struct {
|
|
20
|
18
|
buf bytes.Buffer
|
|
21
|
19
|
|
|
|
@@ -41,35 +39,29 @@ func (w *wrapperBlockCipher) WriteTimeout(p []byte, timeout time.Duration) (int,
|
|
41
|
39
|
}
|
|
42
|
40
|
|
|
43
|
41
|
func (w *wrapperBlockCipher) Read(p []byte) (int, error) {
|
|
44
|
|
- return w.read(p, readAll)
|
|
45
|
|
-
|
|
46
|
|
-}
|
|
47
|
|
-
|
|
48
|
|
-func (w *wrapperBlockCipher) ReadTimeout(p []byte, timeout time.Duration) (int, error) {
|
|
49
|
|
- return w.read(p, readAllTimeout(timeout))
|
|
50
|
|
-}
|
|
51
|
|
-
|
|
52
|
|
-func (w *wrapperBlockCipher) read(p []byte,
|
|
53
|
|
- reader func(conntypes.StreamReadWriteCloser) ([]byte, error)) (int, error) {
|
|
54
|
42
|
if w.buf.Len() > 0 {
|
|
55
|
43
|
return w.flush(p)
|
|
56
|
44
|
}
|
|
57
|
45
|
|
|
58
|
|
- var buf []byte
|
|
59
|
|
- for len(buf) == 0 || len(buf)%aes.BlockSize != 0 {
|
|
60
|
|
- rv, err := reader(w.parent)
|
|
|
46
|
+ var currentBuffer []byte
|
|
|
47
|
+ for len(currentBuffer) == 0 || len(currentBuffer)%aes.BlockSize != 0 {
|
|
|
48
|
+ rv, err := utils.ReadFull(w.parent)
|
|
61
|
49
|
if err != nil {
|
|
62
|
|
- return 0, fmt.Errorf("cannot read from socket: %w", err)
|
|
|
50
|
+ return 0, fmt.Errorf("cannot read data: %w", err)
|
|
63
|
51
|
}
|
|
64
|
|
- buf = append(buf, rv...)
|
|
|
52
|
+ currentBuffer = append(currentBuffer, rv...)
|
|
65
|
53
|
}
|
|
66
|
54
|
|
|
67
|
|
- w.decryptor.CryptBlocks(buf, buf)
|
|
68
|
|
- w.buf.Write(buf)
|
|
|
55
|
+ w.decryptor.CryptBlocks(currentBuffer, currentBuffer)
|
|
|
56
|
+ w.buf.Write(currentBuffer)
|
|
69
|
57
|
|
|
70
|
58
|
return w.flush(p)
|
|
71
|
59
|
}
|
|
72
|
60
|
|
|
|
61
|
+func (w *wrapperBlockCipher) ReadTimeout(p []byte, timeout time.Duration) (int, error) {
|
|
|
62
|
+ return w.Read(p)
|
|
|
63
|
+}
|
|
|
64
|
+
|
|
73
|
65
|
func (w *wrapperBlockCipher) flush(p []byte) (int, error) {
|
|
74
|
66
|
if w.buf.Len() > len(p) {
|
|
75
|
67
|
return w.buf.Read(p)
|
|
|
@@ -93,44 +85,6 @@ func (w *wrapperBlockCipher) encrypt(p []byte) ([]byte, error) {
|
|
93
|
85
|
return encrypted, nil
|
|
94
|
86
|
}
|
|
95
|
87
|
|
|
96
|
|
-func readAll(src conntypes.StreamReadWriteCloser) (rv []byte, err error) {
|
|
97
|
|
- buf := make([]byte, blockCipherReadCurrentDataBufferSize)
|
|
98
|
|
- n := blockCipherReadCurrentDataBufferSize
|
|
99
|
|
-
|
|
100
|
|
- for n == len(buf) {
|
|
101
|
|
- n, err = src.Read(buf)
|
|
102
|
|
- if err != nil {
|
|
103
|
|
- return nil, err
|
|
104
|
|
- }
|
|
105
|
|
- rv = append(rv, buf[:n]...)
|
|
106
|
|
- }
|
|
107
|
|
-
|
|
108
|
|
- return rv, nil
|
|
109
|
|
-}
|
|
110
|
|
-
|
|
111
|
|
-func readAllTimeout(timeout time.Duration) func(conntypes.StreamReadWriteCloser) ([]byte, error) {
|
|
112
|
|
- return func(src conntypes.StreamReadWriteCloser) (rv []byte, err error) {
|
|
113
|
|
- tmo := timeout
|
|
114
|
|
- buf := make([]byte, blockCipherReadCurrentDataBufferSize)
|
|
115
|
|
- n := blockCipherReadCurrentDataBufferSize
|
|
116
|
|
-
|
|
117
|
|
- for n == len(buf) {
|
|
118
|
|
- if tmo <= 0 {
|
|
119
|
|
- return nil, errors.New("timeout")
|
|
120
|
|
- }
|
|
121
|
|
- startTime := time.Now()
|
|
122
|
|
- n, err = src.ReadTimeout(buf, tmo)
|
|
123
|
|
- if err != nil {
|
|
124
|
|
- return nil, err
|
|
125
|
|
- }
|
|
126
|
|
- rv = append(rv, buf[:n]...)
|
|
127
|
|
- tmo -= time.Since(startTime)
|
|
128
|
|
- }
|
|
129
|
|
-
|
|
130
|
|
- return rv, nil
|
|
131
|
|
- }
|
|
132
|
|
-}
|
|
133
|
|
-
|
|
134
|
88
|
func (w *wrapperBlockCipher) Close() error {
|
|
135
|
89
|
return w.parent.Close()
|
|
136
|
90
|
}
|