Ver código fonte

reworked hub

tags/1.0^2
9seconds 6 anos atrás
pai
commit
413cafeeb6
7 arquivos alterados com 181 adições e 140 exclusões
  1. 40
    37
      hub/connection.go
  2. 54
    49
      hub/connection_hub.go
  3. 0
    8
      hub/connection_hub_request.go
  4. 11
    10
      hub/ctx_channel.go
  5. 38
    26
      hub/hub.go
  6. 30
    0
      hub/init.go
  7. 8
    10
      hub/registry.go

+ 40
- 37
hub/connection.go Ver arquivo

@@ -10,21 +10,31 @@ import (
10 10
 	"github.com/9seconds/mtg/protocol"
11 11
 )
12 12
 
13
-type connectionID int
14
-
15 13
 type connection struct {
16
-	conn    conntypes.PacketReadWriteCloser
17
-	mutex   sync.RWMutex
18
-	id      connectionID
19
-	hub     *connectionHub
20
-	pending uint
21
-	closing bool
14
+	conn         conntypes.PacketReadWriteCloser
15
+	mutex        sync.RWMutex
16
+	shutdownOnce sync.Once
17
+	hub          *connectionHub
18
+	id           int
19
+	pending      uint
20
+	done         chan struct{}
22 21
 }
23 22
 
24
-func (c *connection) Write(packet conntypes.Packet) error {
23
+func (c *connection) read() (conntypes.Packet, error) {
24
+	packet, err := c.conn.Read()
25
+
25 26
 	c.mutex.Lock()
26
-	defer c.mutex.Unlock()
27
+	if err != nil {
28
+		c.pending--
29
+	} else {
30
+		c.pending = 0
31
+	}
32
+	c.mutex.Unlock()
33
+
34
+	return packet, err
35
+}
27 36
 
37
+func (c *connection) write(packet conntypes.Packet) error {
28 38
 	err := c.conn.Write(packet)
29 39
 	if err != nil {
30 40
 		// if we tried to write into a socket and it was broken, it is
@@ -32,58 +42,51 @@ func (c *connection) Write(packet conntypes.Packet) error {
32 42
 		//
33 43
 		// probably we need to remove it completely because it seems
34 44
 		// that connection is broken.
45
+		c.mutex.Lock()
35 46
 		c.pending = 0
47
+		c.mutex.Unlock()
36 48
 	}
37 49
 	return err
38 50
 }
39 51
 
40
-func (c *connection) Read() (conntypes.Packet, error) {
41
-	packet, err := c.conn.Read()
52
+func (c *connection) shutdown() {
53
+	c.shutdownOnce.Do(func() {
54
+		close(c.done)
55
+			c.hub.channelBrokenSockets <- c.id
56
+	})
57
+}
42 58
 
43
-	c.mutex.Lock()
44
-	if err != nil {
45
-		c.pending--
46
-	} else {
47
-		c.pending = 0
59
+func (c *connection) closed() bool {
60
+	select {
61
+	case <-c.done:
62
+		return true
63
+	default:
64
+		return false
48 65
 	}
49
-	c.mutex.Unlock()
50
-
51
-	return packet, err
52 66
 }
53 67
 
54
-func (c *connection) Stats() (bool, uint) {
68
+func (c *connection) idle() bool {
55 69
 	c.mutex.RLock()
56 70
 	defer c.mutex.RUnlock()
57 71
 
58
-	return c.closing, c.pending
59
-}
60
-
61
-func (c *connection) Close() error {
62
-	c.mutex.Lock()
63
-	defer c.mutex.Unlock()
64
-
65
-	c.closing = true
66
-	return c.conn.Close()
72
+	return c.pending == 0
67 73
 }
68 74
 
69 75
 func (c *connection) run() {
70 76
 	for {
71
-		packet, err := c.conn.Read()
77
+		packet, err := c.read()
72 78
 		if err != nil {
73
-			c.Close()
74
-			c.hub.brokenSocketsChan <- c.id
75
-			c.hub = nil
79
+			c.shutdown()
76 80
 			return
77 81
 		}
78 82
 
79
-		// TODO
80 83
 		if channel, ok := Registry.getChannel(conntypes.ConnID{}); ok {
81 84
 			go channel.write(packet) // nolint: errcheck
82 85
 		}
83 86
 	}
84 87
 }
85 88
 
86
-func newConnection(hub *connectionHub, req *protocol.TelegramRequest) (*connection, error) {
89
+func newConnection(req *protocol.TelegramRequest, hub *connectionHub) (*connection, error) {
87 90
 	conn, err := mtproto.TelegramProtocol(req)
88 91
 	if err != nil {
89 92
 		return nil, fmt.Errorf("cannot create a new connection: %w", err)
@@ -92,7 +95,7 @@ func newConnection(hub *connectionHub, req *protocol.TelegramRequest) (*connecti
92 95
 	rv := &connection{
93 96
 		conn: conn,
94 97
 		hub:  hub,
95
-		id:   connectionID(rand.Int()),
98
+		id:   rand.Int(),
96 99
 	}
97 100
 	go rv.run()
98 101
 

+ 54
- 49
hub/connection_hub.go Ver arquivo

@@ -1,84 +1,89 @@
1 1
 package hub
2 2
 
3
-import "time"
3
+import (
4
+	"time"
5
+
6
+	"github.com/9seconds/mtg/protocol"
7
+)
4 8
 
5 9
 const hubGCEvery = time.Minute
6 10
 
11
+type connectionHubRequest struct {
12
+	request  *protocol.TelegramRequest
13
+	response chan<- *connection
14
+}
15
+
7 16
 type connectionHub struct {
8
-	sockets map[connectionID]*connection
17
+	sockets map[int]*connection
9 18
 
10
-	brokenSocketsChan      chan connectionID
11
-	connectionRequestsChan chan *connectionHubRequest
12
-	returnConnectionsChan  chan *connection
19
+	channelBrokenSockets      chan int
20
+	channelConnectionRequests chan *connectionHubRequest
21
+	channelReturnConnections  chan *connection
13 22
 }
14 23
 
15
-func (h *connectionHub) run() {
16
-	gcTicker := time.NewTicker(hubGCEvery)
17
-	defer gcTicker.Stop()
24
+func (c *connectionHub) run() {
25
+	ticker := time.NewTicker(hubGCEvery)
26
+	defer ticker.Stop()
18 27
 
19 28
 	for {
20 29
 		select {
21
-		case <-gcTicker.C:
22
-			h.runGC()
23
-		case id := <-h.brokenSocketsChan:
24
-			h.runBrokenConnection(id)
25
-		case request := <-h.connectionRequestsChan:
26
-			h.runConnectionRequest(request)
27
-		case conn := <-h.returnConnectionsChan:
28
-			h.runReturnConnection(conn)
30
+		case <-ticker.C:
31
+			c.runGC()
32
+		case request := <-c.channelConnectionRequests:
33
+			c.runConnectionRequest(request)
34
+		case id := <-c.channelBrokenSockets:
35
+			c.runBrokenSocket(id)
36
+		case conn := <-c.channelReturnConnections:
37
+			c.runReturnConnection(conn)
29 38
 		}
30 39
 	}
31 40
 }
32 41
 
33
-func (h *connectionHub) runBrokenConnection(id connectionID) {
34
-	delete(h.sockets, id)
35
-}
36
-
37
-func (h *connectionHub) runGC() {
38
-	for key, conn := range h.sockets {
39
-		closing, pending := conn.Stats()
42
+func (c *connectionHub) runGC() {
43
+	for key, conn := range c.sockets {
40 44
 		switch {
41
-		case closing:
42
-			delete(h.sockets, key)
43
-		case pending == 0:
44
-			conn.Close()
45
-			delete(h.sockets, key)
45
+		case conn.closed():
46
+			delete(c.sockets, key)
47
+		case conn.idle():
48
+			conn.shutdown()
49
+			delete(c.sockets, key)
46 50
 			return
47 51
 		}
48
-
49 52
 	}
50 53
 }
51 54
 
52
-func (h *connectionHub) runConnectionRequest(req *connectionHubRequest) {
53
-	for key, conn := range h.sockets {
54
-		closing, _ := conn.Stats()
55
-		delete(h.sockets, key)
56
-
57
-		if !closing {
58
-			req.responseChan <- conn
55
+func (c *connectionHub) runConnectionRequest(req *connectionHubRequest) {
56
+	for key, conn := range c.sockets {
57
+		delete(c.sockets, key)
58
+		if !conn.closed() {
59
+			req.response <- conn
60
+			close(req.response)
59 61
 			return
60 62
 		}
61 63
 	}
62 64
 
63
-	newConn, err := newConnection(h, req.req)
64
-	if err != nil {
65
-		close(req.responseChan)
66
-		return
65
+	if conn, err := newConnection(req.request, c); err == nil {
66
+		req.response <- conn
67 67
 	}
68
+	close(req.response)
69
+}
68 70
 
69
-	req.responseChan <- newConn
71
+func (c *connectionHub) runBrokenSocket(id int) {
72
+	delete(c.sockets, id)
70 73
 }
71 74
 
72
-func (h *connectionHub) runReturnConnection(conn *connection) {
73
-	h.sockets[conn.id] = conn
75
+func (c *connectionHub) runReturnConnection(conn *connection) {
76
+	c.sockets[conn.id] = conn
74 77
 }
75 78
 
76 79
 func newConnectionHub() *connectionHub {
77
-	return &connectionHub{
78
-		sockets: map[connectionID]*connection{},
79
-
80
-		brokenSocketsChan:      make(chan connectionID, 1),
81
-		connectionRequestsChan: make(chan *connectionHubRequest),
82
-		returnConnectionsChan:  make(chan *connection, 1),
80
+	rv := &connectionHub{
81
+		sockets:                   map[int]*connection{},
82
+		channelBrokenSockets:      make(chan int, 1),
83
+		channelConnectionRequests: make(chan *connectionHubRequest),
84
+		channelReturnConnections:  make(chan *connection, 1),
83 85
 	}
86
+	go rv.run()
87
+
88
+	return rv
84 89
 }

+ 0
- 8
hub/connection_hub_request.go Ver arquivo

@@ -1,8 +0,0 @@
1
-package hub
2
-
3
-import "github.com/9seconds/mtg/protocol"
4
-
5
-type connectionHubRequest struct {
6
-	req          *protocol.TelegramRequest
7
-	responseChan chan<- *connection
8
-}

hub/closeable_channel.go → hub/ctx_channel.go Ver arquivo

@@ -12,46 +12,47 @@ const closeableChannelReadTimeout = 2 * time.Minute
12 12
 
13 13
 type ChannelReadCloser interface {
14 14
 	Read() (conntypes.Packet, error)
15
-	Close()
15
+	Close() error
16 16
 }
17 17
 
18
-type closeableChannel struct {
18
+type ctxChannel struct {
19 19
 	channel chan conntypes.Packet
20 20
 	ctx     context.Context
21 21
 	cancel  context.CancelFunc
22 22
 }
23 23
 
24
-func (c *closeableChannel) Read() (conntypes.Packet, error) {
24
+func (c *ctxChannel) Read() (conntypes.Packet, error) {
25 25
 	timer := time.NewTimer(closeableChannelReadTimeout)
26 26
 	defer timer.Stop()
27 27
 
28 28
 	select {
29 29
 	case <-timer.C:
30
-		return nil, errors.New("timeout")
30
+		return nil, ErrTimeout
31 31
 	case <-c.ctx.Done():
32
-		return nil, errors.New("channel was closed")
32
+		return nil, ErrClosed
33 33
 	case packet := <-c.channel:
34 34
 		return packet, nil
35 35
 	}
36 36
 }
37 37
 
38
-func (c *closeableChannel) write(packet conntypes.Packet) error {
38
+func (c *ctxChannel) write(packet conntypes.Packet) error {
39 39
 	select {
40 40
 	case <-c.ctx.Done():
41
-		return errors.New("channel was closed")
41
+		return ErrClosed
42 42
 	case c.channel <- packet:
43 43
 		return nil
44 44
 	}
45 45
 }
46 46
 
47
-func (c *closeableChannel) Close() {
47
+func (c *ctxChannel) Close() error {
48 48
 	c.cancel()
49 49
 	c.channel = nil
50
+	return nil
50 51
 }
51 52
 
52
-func newCloseableChannel(ctx context.Context) *closeableChannel {
53
+func newCtxChannel(ctx context.Context) *ctxChannel {
53 54
 	ctx, cancel := context.WithCancel(ctx)
54
-	return &closeableChannel{
55
+	return &ctxChannel{
55 56
 		channel: make(chan conntypes.Packet),
56 57
 		ctx:     ctx,
57 58
 		cancel:  cancel,

+ 38
- 26
hub/hub.go Ver arquivo

@@ -1,48 +1,60 @@
1 1
 package hub
2 2
 
3 3
 import (
4
-	"errors"
4
+	"encoding/binary"
5
+	"fmt"
6
+	"strings"
5 7
 	"sync"
6 8
 
7 9
 	"github.com/9seconds/mtg/conntypes"
8 10
 	"github.com/9seconds/mtg/protocol"
9 11
 )
10 12
 
11
-type Concentrator struct {
12
-	hubs sync.Map
13
+type hub struct {
14
+	subs  map[string]*connectionHub
15
+	mutex sync.RWMutex
13 16
 }
14 17
 
15
-func (c *Concentrator) Write(packet conntypes.Packet, req *protocol.TelegramRequest) error {
16
-	hub := c.getHub(req)
17
-	connectionChan := make(chan *connection)
18
-	hub.connectionRequestsChan <- &connectionHubRequest{
19
-		req:          req,
20
-		responseChan: connectionChan,
18
+func (h *hub) Write(packet conntypes.Packet, req *protocol.TelegramRequest) error {
19
+	sub := h.getHub(req)
20
+	connections := make(chan *connection)
21
+	sub.channelConnectionRequests <- &connectionHubRequest{
22
+		request:  req,
23
+		response: connections,
21 24
 	}
22 25
 
23
-	conn, ok := <-connectionChan
26
+	conn, ok := <-connections
24 27
 	if !ok {
25
-		return errors.New("cannot establish connection to telegram")
28
+		return ErrCannotCreateConnection
26 29
 	}
27
-}
28 30
 
29
-func (c *Concentrator) getHub(req *protocol.TelegramRequest) *connectionHub {
30
-	dcMapRaw, ok := c.hubs.Load(req.ClientProtocol.DC())
31
-	if !ok {
32
-		dcMapRaw, _ = c.hubs.LoadOrStore(req.ClientProtocol.DC(), &sync.Map{})
31
+	if err := conn.write(packet); err != nil {
32
+		return fmt.Errorf("cannot send packet: %w", err)
33 33
 	}
34
-	dcMap := dcMapRaw.(*sync.Map)
34
+	return nil
35
+}
36
+
37
+func (h *hub) getHub(req *protocol.TelegramRequest) *connectionHub {
38
+	keyBuilder := strings.Builder{}
39
+	binary.Write(&keyBuilder, binary.LittleEndian, int16(req.ClientProtocol.DC()))
40
+	keyBuilder.WriteRune('_')
41
+	binary.Write(&keyBuilder, binary.LittleEndian, uint8(req.ClientProtocol.ConnectionProtocol()))
42
+	key := keyBuilder.String()
43
+
44
+	h.mutex.RLock()
45
+	rv, ok := h.subs[key]
46
+	h.mutex.RUnlock()
35 47
 
36
-	loaded := true
37
-	hubRaw, ok := dcMap.Load(req.ClientProtocol.ConnectionProtocol())
38 48
 	if !ok {
39
-		hubRaw, loaded = dcMap.LoadOrStore(req.ClientProtocol.ConnectionProtocol(),
40
-			newConnectionHub())
41
-	}
42
-	hub := hubRaw.(*connectionHub)
43
-	if !loaded {
44
-		go hub.run()
49
+		h.mutex.Lock()
50
+		defer h.mutex.Unlock()
51
+
52
+		rv, ok = h.subs[key]
53
+		if !ok {
54
+			rv = newConnectionHub()
55
+			h.subs[key] = rv
56
+		}
45 57
 	}
46 58
 
47
-	return hub
59
+	return rv
48 60
 }

+ 30
- 0
hub/init.go Ver arquivo

@@ -0,0 +1,30 @@
1
+package hub
2
+
3
+import (
4
+	"context"
5
+	"errors"
6
+	"sync"
7
+)
8
+
9
+var (
10
+	Registry *registry
11
+	Hub      *hub
12
+
13
+	ErrTimeout                = errors.New("timeout")
14
+	ErrClosed                 = errors.New("channel was closed")
15
+	ErrCannotCreateConnection = errors.New("cannot create connection")
16
+
17
+	initOnce sync.Once
18
+)
19
+
20
+func Init(ctx context.Context) {
21
+	initOnce.Do(func() {
22
+		Registry = &registry{
23
+			conns: map[string]*ctxChannel{},
24
+			ctx:   ctx,
25
+		}
26
+		Hub = &hub{
27
+			subs: map[string]*connectionHub{},
28
+		}
29
+	})
30
+}

+ 8
- 10
hub/registry.go Ver arquivo

@@ -7,16 +7,14 @@ import (
7 7
 	"github.com/9seconds/mtg/conntypes"
8 8
 )
9 9
 
10
-var Registry *RegistryStruct
11
-
12
-type RegistryStruct struct {
13
-	conns map[string]*closeableChannel
10
+type registry struct {
11
+	conns map[string]*ctxChannel
14 12
 	ctx   context.Context
15 13
 	mutex sync.RWMutex
16 14
 }
17 15
 
18
-func (r *RegistryStruct) Register(id conntypes.ConnID) ChannelReadCloser {
19
-	channel := newCloseableChannel(r.ctx)
16
+func (r *registry) Register(id conntypes.ConnID) ChannelReadCloser {
17
+	channel := newCtxChannel(r.ctx)
20 18
 
21 19
 	r.mutex.Lock()
22 20
 	r.conns[string(id[:])] = channel
@@ -25,7 +23,7 @@ func (r *RegistryStruct) Register(id conntypes.ConnID) ChannelReadCloser {
25 23
 	return channel
26 24
 }
27 25
 
28
-func (r *RegistryStruct) Unregister(id conntypes.ConnID) {
26
+func (r *registry) Unregister(id conntypes.ConnID) {
29 27
 	r.mutex.Lock()
30 28
 	defer r.mutex.Unlock()
31 29
 
@@ -35,7 +33,7 @@ func (r *RegistryStruct) Unregister(id conntypes.ConnID) {
35 33
 	}
36 34
 }
37 35
 
38
-func (r *RegistryStruct) getChannel(id conntypes.ConnID) (*closeableChannel, bool) {
36
+func (r *registry) getChannel(id conntypes.ConnID) (*ctxChannel, bool) {
39 37
 	r.mutex.RLock()
40 38
 	defer r.mutex.RUnlock()
41 39
 
@@ -46,8 +44,8 @@ func (r *RegistryStruct) getChannel(id conntypes.ConnID) (*closeableChannel, boo
46 44
 }
47 45
 
48 46
 func InitRegistry(ctx context.Context) {
49
-	Registry = &RegistryStruct{
47
+	Registry = &registry{
50 48
 		ctx:   ctx,
51
-		conns: map[string]*closeableChannel{},
49
+		conns: map[string]*ctxChannel{},
52 50
 	}
53 51
 }

Carregando…
Cancelar
Salvar