| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163 |
- package hub
-
- import (
- "fmt"
- "math/rand"
- "sync"
-
- "go.uber.org/zap"
-
- "mtg/conntypes"
- "mtg/mtproto"
- "mtg/mtproto/rpc"
- "mtg/protocol"
- )
-
- type connection struct {
- conn conntypes.PacketReadWriteCloser
- proxyConns map[string]*ProxyConn
- closeOnce sync.Once
- proxyConnsMutex sync.RWMutex
- id int
- logger *zap.SugaredLogger
-
- channelDone chan struct{}
- channelWrite chan conntypes.Packet
- channelRead chan *rpc.ProxyResponse
- channelConnAttach chan *ProxyConn
- channelConnDetach chan conntypes.ConnID
- }
-
- func (c *connection) run() {
- defer c.Close()
-
- for {
- select {
- case <-c.channelDone:
- for _, v := range c.proxyConns {
- v.Close()
- }
-
- return
- case resp := <-c.channelRead:
- if channel, ok := c.proxyConns[string(resp.ConnID[:])]; ok {
- if resp.Type == rpc.ProxyResponseTypeCloseExt {
- channel.Close()
- } else {
- channel.put(resp)
- }
- }
- case packet := <-c.channelWrite:
- if err := c.conn.Write(packet); err != nil {
- c.logger.Debugw("Cannot write packet", "error", err)
- c.Close()
- }
- case conn := <-c.channelConnAttach:
- c.proxyConnsMutex.Lock()
- c.proxyConns[string(conn.req.ConnID[:])] = conn
- c.proxyConnsMutex.Unlock()
- conn.channelWrite = c.channelWrite
- case connID := <-c.channelConnDetach:
- if conn, ok := c.proxyConns[string(connID[:])]; ok {
- c.proxyConnsMutex.Lock()
- delete(c.proxyConns, string(connID[:]))
- c.proxyConnsMutex.Unlock()
- conn.Close()
- }
- }
- }
- }
-
- func (c *connection) readLoop() {
- for {
- packet, err := c.conn.Read()
- if err != nil {
- c.logger.Debugw("Cannot read packet", "error", err)
- c.Close()
-
- return
- }
-
- response, err := rpc.ParseProxyResponse(packet)
- if err != nil {
- c.logger.Debugw("Failed response", "error", err)
- continue
- }
-
- select {
- case <-c.channelDone:
- return
- case c.channelRead <- response:
- }
- }
- }
-
- func (c *connection) Close() {
- c.closeOnce.Do(func() {
- c.logger.Debugw("Closing connection")
-
- close(c.channelDone)
- c.conn.Close()
- })
- }
-
- func (c *connection) Done() bool {
- select {
- case <-c.channelDone:
- return true
- default:
- return c.Len() == 0
- }
- }
-
- func (c *connection) Len() int {
- c.proxyConnsMutex.RLock()
- defer c.proxyConnsMutex.RUnlock()
-
- return len(c.proxyConns)
- }
-
- func (c *connection) Attach(conn *ProxyConn) error {
- select {
- case <-c.channelDone:
- return ErrClosed
- case c.channelConnAttach <- conn:
- return nil
- }
- }
-
- func (c *connection) Detach(connID conntypes.ConnID) {
- select {
- case <-c.channelDone:
- case c.channelConnDetach <- connID:
- }
- }
-
- func newConnection(req *protocol.TelegramRequest) (*connection, error) {
- conn, err := mtproto.TelegramProtocol(req)
- if err != nil {
- return nil, fmt.Errorf("cannot create a new connection: %w", err)
- }
-
- id := rand.Int() // nolint: gosec
- rv := &connection{
- conn: conn,
- id: id,
- logger: zap.S().Named("hub-connection").With("id", id,
- "dc", req.ClientProtocol.DC(),
- "protocol", req.ClientProtocol.ConnectionProtocol()),
- proxyConns: make(map[string]*ProxyConn),
-
- channelRead: make(chan *rpc.ProxyResponse, 1),
- channelDone: make(chan struct{}),
- channelWrite: make(chan conntypes.Packet),
- channelConnAttach: make(chan *ProxyConn),
- channelConnDetach: make(chan conntypes.ConnID),
- }
-
- go rv.readLoop()
-
- go rv.run()
-
- return rv, nil
- }
|