mirror of
https://github.com/binwiederhier/ntfy.git
synced 2024-12-14 11:47:33 +00:00
Fix tests, lock topic as short as possible
This commit is contained in:
parent
85f2252a77
commit
c1f7bed8d1
3 changed files with 33 additions and 20 deletions
|
@ -34,9 +34,9 @@ func testCacheMessages(t *testing.T, c *messageCache) {
|
|||
require.Equal(t, errUnexpectedMessageType, c.AddMessage(newOpenMessage("example"))) // These should not be added!
|
||||
|
||||
// mytopic: count
|
||||
count, err := c.MessageCounts("mytopic")
|
||||
counts, err := c.MessageCounts()
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 2, count)
|
||||
require.Equal(t, 2, counts["mytopic"])
|
||||
|
||||
// mytopic: since all
|
||||
messages, _ := c.Messages("mytopic", sinceAllMessages, false)
|
||||
|
@ -66,18 +66,18 @@ func testCacheMessages(t *testing.T, c *messageCache) {
|
|||
require.Equal(t, "my other message", messages[0].Message)
|
||||
|
||||
// example: count
|
||||
count, err = c.MessageCounts("example")
|
||||
counts, err = c.MessageCounts()
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 1, count)
|
||||
require.Equal(t, 1, counts["example"])
|
||||
|
||||
// example: since all
|
||||
messages, _ = c.Messages("example", sinceAllMessages, false)
|
||||
require.Equal(t, "my example message", messages[0].Message)
|
||||
|
||||
// non-existing: count
|
||||
count, err = c.MessageCounts("doesnotexist")
|
||||
counts, err = c.MessageCounts()
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 0, count)
|
||||
require.Equal(t, 0, counts["doesnotexist"])
|
||||
|
||||
// non-existing: since all
|
||||
messages, _ = c.Messages("doesnotexist", sinceAllMessages, false)
|
||||
|
@ -255,13 +255,13 @@ func testCachePrune(t *testing.T, c *messageCache) {
|
|||
require.Nil(t, c.AddMessage(m3))
|
||||
require.Nil(t, c.Prune(time.Unix(2, 0)))
|
||||
|
||||
count, err := c.MessageCounts("mytopic")
|
||||
counts, err := c.MessageCounts()
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 1, count)
|
||||
require.Equal(t, 1, counts["mytopic"])
|
||||
|
||||
count, err = c.MessageCounts("another_topic")
|
||||
counts, err = c.MessageCounts()
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 0, count)
|
||||
require.Equal(t, 0, counts["another_topic"])
|
||||
|
||||
messages, err := c.Messages("mytopic", sinceAllMessages, false)
|
||||
require.Nil(t, err)
|
||||
|
|
|
@ -1090,7 +1090,7 @@ func (s *Server) updateStatsAndPrune() {
|
|||
staleVisitors := 0
|
||||
for ip, v := range s.visitors {
|
||||
if v.Stale() {
|
||||
log.Debug("Deleting stale visitor %s", v.ip)
|
||||
log.Trace("Deleting stale visitor %s", v.ip)
|
||||
delete(s.visitors, ip)
|
||||
staleVisitors++
|
||||
}
|
||||
|
@ -1131,13 +1131,14 @@ func (s *Server) updateStatsAndPrune() {
|
|||
messages += count
|
||||
}
|
||||
|
||||
// Prune old topics, remove subscriptions without subscribers
|
||||
// Remove subscriptions without subscribers
|
||||
s.mu.Lock()
|
||||
var subscribers int
|
||||
for _, t := range s.topics {
|
||||
subs := t.Subscribers()
|
||||
subs := t.SubscribersCount()
|
||||
msgs, exists := messageCounts[t.ID]
|
||||
if subs == 0 && (!exists || msgs == 0) {
|
||||
log.Trace("Deleting empty topic %s", t.ID)
|
||||
delete(s.topics, t.ID)
|
||||
continue
|
||||
}
|
||||
|
|
|
@ -44,11 +44,12 @@ func (t *topic) Unsubscribe(id int) {
|
|||
// Publish asynchronously publishes to all subscribers
|
||||
func (t *topic) Publish(v *visitor, m *message) error {
|
||||
go func() {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
if len(t.subscribers) > 0 {
|
||||
log.Debug("%s Forwarding to %d subscriber(s)", logMessagePrefix(v, m), len(t.subscribers))
|
||||
for _, s := range t.subscribers {
|
||||
// We want to lock the topic as short as possible, so we make a shallow copy of the
|
||||
// subscribers map here. Actually sending out the messages then doesn't have to lock.
|
||||
subscribers := t.subscribersCopy()
|
||||
if len(subscribers) > 0 {
|
||||
log.Debug("%s Forwarding to %d subscriber(s)", logMessagePrefix(v, m), len(subscribers))
|
||||
for _, s := range subscribers {
|
||||
if err := s(v, m); err != nil {
|
||||
log.Warn("%s Error forwarding to subscriber", logMessagePrefix(v, m))
|
||||
}
|
||||
|
@ -60,9 +61,20 @@ func (t *topic) Publish(v *visitor, m *message) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// Subscribers returns the number of subscribers to this topic
|
||||
func (t *topic) Subscribers() int {
|
||||
// SubscribersCount returns the number of subscribers to this topic
|
||||
func (t *topic) SubscribersCount() int {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
return len(t.subscribers)
|
||||
}
|
||||
|
||||
// subscribersCopy returns a shallow copy of the subscribers map
|
||||
func (t *topic) subscribersCopy() map[int]subscriber {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
subscribers := make(map[int]subscriber)
|
||||
for k, v := range t.subscribers {
|
||||
subscribers[k] = v
|
||||
}
|
||||
return subscribers
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue