| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177 |
- package wrappers
-
- import (
- "context"
- "crypto/rand"
- "encoding/hex"
- "net"
- "time"
-
- "github.com/juju/errors"
- "go.uber.org/zap"
-
- "github.com/9seconds/mtg/config"
- )
-
// ConnIDLength is the size of a connection identifier in bytes.
const ConnIDLength = 8

// ConnID is a fixed-size identifier of a connection.
type ConnID [ConnIDLength]byte

// String returns the identifier as a lowercase hexadecimal string, two
// characters per byte.
func (c ConnID) String() string {
	raw := c[:]

	return hex.EncodeToString(raw)
}
-
// connPurpose marks what a wrapped connection is used for. newConn checks it
// to decide whether the connection ID is attached to the logger.
type connPurpose uint8

// Each purpose occupies its own bit: client is 1, telegram is 2.
const (
	connPurposeClient = connPurpose(1 << iota)
	connPurposeTelegram
)
-
// Deadlines applied by Read and Write when the caller does not pass an
// explicit timeout.
const (
	connTimeoutRead  time.Duration = 2 * time.Minute
	connTimeoutWrite time.Duration = 2 * time.Minute
)
-
- type wrapperConn struct {
- parent net.Conn
- ctx context.Context
- cancel context.CancelFunc
- connID ConnID
- logger *zap.SugaredLogger
- localAddr *net.TCPAddr
- remoteAddr *net.TCPAddr
- }
-
- func (w *wrapperConn) WriteTimeout(p []byte, timeout time.Duration) (int, error) {
- select {
- case <-w.ctx.Done():
- w.Close()
- return 0, errors.Annotate(w.ctx.Err(), "Cannot write because context was closed")
-
- default:
- if err := w.parent.SetWriteDeadline(time.Now().Add(timeout)); err != nil {
- w.Close() // nolint: gosec
- return 0, errors.Annotate(err, "Cannot set write deadline to the socket")
- }
-
- n, err := w.parent.Write(p)
- w.logger.Debugw("Write to stream", "bytes", n, "error", err)
- if err != nil {
- w.Close() // nolint: gosec
- }
-
- return n, err
- }
- }
-
- func (w *wrapperConn) Write(p []byte) (int, error) {
- return w.WriteTimeout(p, connTimeoutWrite)
- }
-
- func (w *wrapperConn) ReadTimeout(p []byte, timeout time.Duration) (int, error) {
- select {
- case <-w.ctx.Done():
- w.Close()
- return 0, errors.Annotate(w.ctx.Err(), "Cannot read because context was closed")
-
- default:
- if err := w.parent.SetReadDeadline(time.Now().Add(timeout)); err != nil {
- w.Close()
- return 0, errors.Annotate(err, "Cannot set read deadline to the socket")
- }
-
- n, err := w.parent.Read(p)
- w.logger.Debugw("Read from stream", "bytes", n, "error", err)
- if err != nil {
- w.Close()
- }
-
- return n, err
- }
- }
-
- func (w *wrapperConn) Read(p []byte) (int, error) {
- return w.ReadTimeout(p, connTimeoutRead)
- }
-
- func (w *wrapperConn) Close() error {
- w.logger.Debugw("Close connection")
- w.cancel()
-
- return w.parent.Close()
- }
-
- func (w *wrapperConn) Conn() net.Conn {
- return w.parent
- }
-
- func (w *wrapperConn) Logger() *zap.SugaredLogger {
- return w.logger
- }
-
- func (w *wrapperConn) LocalAddr() *net.TCPAddr {
- return w.localAddr
- }
-
- func (w *wrapperConn) RemoteAddr() *net.TCPAddr {
- return w.remoteAddr
- }
-
- func newConn(ctx context.Context,
- cancel context.CancelFunc,
- parent net.Conn,
- connID ConnID,
- purpose connPurpose) StreamReadWriteCloser {
- localAddr := *parent.LocalAddr().(*net.TCPAddr)
-
- if parent.RemoteAddr().(*net.TCPAddr).IP.To4() != nil {
- if config.C.PublicIPv4Addr.IP != nil {
- localAddr.IP = config.C.PublicIPv4Addr.IP
- }
- } else if config.C.PublicIPv6Addr.IP != nil {
- localAddr.IP = config.C.PublicIPv6Addr.IP
- }
-
- logger := zap.S().With(
- "local_address", localAddr,
- "remote_address", parent.RemoteAddr(),
- ).Named("conn")
-
- if purpose == connPurposeClient {
- logger = logger.With("connection_id", connID.String())
- }
-
- return &wrapperConn{
- parent: parent,
- ctx: ctx,
- cancel: cancel,
- connID: connID,
- logger: logger,
- remoteAddr: parent.RemoteAddr().(*net.TCPAddr),
- localAddr: &localAddr,
- }
- }
-
- func NewClientConn(ctx context.Context,
- cancel context.CancelFunc,
- parent net.Conn,
- connID ConnID) StreamReadWriteCloser {
- return newConn(ctx, cancel, parent, connID, connPurposeClient)
- }
-
- func NewTelegramConn(ctx context.Context,
- cancel context.CancelFunc,
- parent net.Conn) StreamReadWriteCloser {
- return newConn(ctx, cancel, parent, ConnID{}, connPurposeTelegram)
- }
-
- func NewConnID() ConnID {
- var id ConnID
-
- if _, err := rand.Read(id[:]); err != nil {
- panic(err)
- }
-
- return id
- }
|