| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159 |
- package wrappers
-
- import (
- "bytes"
- "io"
- "net"
-
- "github.com/juju/errors"
- "go.uber.org/zap"
-
- "github.com/9seconds/mtg/mtproto"
- "github.com/9seconds/mtg/utils"
- )
-
// Abridged transport framing constants. Lengths carried in the frame header
// are counted in 4-byte words, not bytes.
const (
	// mtprotoAbridgedSmallPacketLength (0x7f) splits the two header forms:
	// word counts strictly below it fit in the single header byte, while this
	// exact value is the marker meaning "the word count follows in the next
	// 3 bytes".
	mtprotoAbridgedSmallPacketLength = 0x7f

	// mtprotoAbridgedQuickAckLength (0x80) is the high bit of the first
	// header byte; Read treats it as a quick-ack request and strips it.
	mtprotoAbridgedQuickAckLength = 0x80

	// mtprotoAbridgedLargePacketLength is the exclusive upper bound on word
	// counts representable in the 3-byte extended length field
	// (256*256*256 == 1<<24 == 16777216).
	mtprotoAbridgedLargePacketLength = 1 << 24
)
-
- // MTProtoAbridged presents abridged connection between client and
- // middle proxy.
- type MTProtoAbridged struct {
- conn StreamReadWriteCloser
- opts *mtproto.ConnectionOpts
- logger *zap.SugaredLogger
-
- readCounter uint32
- writeCounter uint32
- }
-
- func (m *MTProtoAbridged) Read() ([]byte, error) {
- defer func() {
- m.readCounter++
- }()
-
- m.logger.Debugw("Read packet",
- "simple_ack", m.opts.ReadHacks.SimpleAck,
- "quick_ack", m.opts.ReadHacks.QuickAck,
- "counter", m.readCounter,
- )
-
- buf := &bytes.Buffer{}
- buf.Grow(3)
-
- if _, err := io.CopyN(buf, m.conn, 1); err != nil {
- return nil, errors.Annotate(err, "Cannot read message length")
- }
- msgLength := uint32(buf.Bytes()[0])
- buf.Reset()
-
- m.logger.Debugw("Packet first byte",
- "byte", msgLength,
- "counter", m.readCounter,
- "simple_ack", m.opts.ReadHacks.SimpleAck,
- "quick_ack", m.opts.ReadHacks.QuickAck,
- )
-
- if msgLength >= mtprotoAbridgedQuickAckLength {
- m.opts.ReadHacks.QuickAck = true
- msgLength -= mtprotoAbridgedQuickAckLength
- }
-
- if msgLength == mtprotoAbridgedSmallPacketLength {
- if _, err := io.CopyN(buf, m.conn, 3); err != nil {
- return nil, errors.Annotate(err, "Cannot read the correct message length")
- }
- number := utils.Uint24{}
- copy(number[:], buf.Bytes())
- msgLength = utils.FromUint24(number)
- }
- msgLength *= 4
-
- m.logger.Debugw("Packet length",
- "length", msgLength,
- "simple_ack", m.opts.ReadHacks.SimpleAck,
- "quick_ack", m.opts.ReadHacks.QuickAck,
- "counter", m.readCounter,
- )
-
- buf.Reset()
- buf.Grow(int(msgLength))
- if _, err := io.CopyN(buf, m.conn, int64(msgLength)); err != nil {
- return nil, errors.Annotate(err, "Cannot read message")
- }
-
- return buf.Bytes(), nil
- }
-
- func (m *MTProtoAbridged) Write(p []byte) (int, error) {
- defer func() {
- m.writeCounter++
- }()
-
- m.logger.Debugw("Write packet",
- "length", len(p),
- "simple_ack", m.opts.WriteHacks.SimpleAck,
- "quick_ack", m.opts.WriteHacks.QuickAck,
- "counter", m.writeCounter,
- )
-
- if len(p)%4 != 0 {
- return 0, errors.Errorf("Incorrect packet length %d", len(p))
- }
-
- if m.opts.WriteHacks.SimpleAck {
- return m.conn.Write(utils.ReverseBytes(p))
- }
-
- packetLength := len(p) / 4
- switch {
- case packetLength < mtprotoAbridgedSmallPacketLength:
- newData := append([]byte{byte(packetLength)}, p...)
- return m.conn.Write(newData)
-
- case packetLength < mtprotoAbridgedLargePacketLength:
- length24 := utils.ToUint24(uint32(packetLength))
-
- buf := &bytes.Buffer{}
- buf.Grow(1 + 3 + len(p))
-
- buf.WriteByte(byte(mtprotoAbridgedSmallPacketLength)) // nolint: gosec
- buf.Write(length24[:]) // nolint: gosec
- buf.Write(p) // nolint: gosec
-
- return m.conn.Write(buf.Bytes())
- }
-
- return 0, errors.Errorf("Packet is too big %d", len(p))
- }
-
- // Logger returns an instance of the logger for this wrapper.
- func (m *MTProtoAbridged) Logger() *zap.SugaredLogger {
- return m.logger
- }
-
- // LocalAddr returns local address of the underlying net.Conn.
- func (m *MTProtoAbridged) LocalAddr() *net.TCPAddr {
- return m.conn.LocalAddr()
- }
-
- // RemoteAddr returns remote address of the underlying net.Conn.
- func (m *MTProtoAbridged) RemoteAddr() *net.TCPAddr {
- return m.conn.RemoteAddr()
- }
-
- // Close closes underlying net.Conn instance.
- func (m *MTProtoAbridged) Close() error {
- return m.conn.Close()
- }
-
- // NewMTProtoAbridged creates new wrapper for abridged client connection.
- func NewMTProtoAbridged(conn StreamReadWriteCloser, opts *mtproto.ConnectionOpts) PacketReadWriteCloser {
- return &MTProtoAbridged{
- conn: conn,
- opts: opts,
- logger: conn.Logger().Named("mtproto-abridged"),
- }
- }
|