1
0
Fork 0
mirror of https://github.com/binwiederhier/ntfy.git synced 2024-12-14 11:47:33 +00:00

Subscribe to more than one topic

This commit is contained in:
Philipp Heckel 2021-11-15 07:56:58 -05:00
parent a481f4c448
commit 52136030be

View file

@ -78,9 +78,9 @@ const (
var ( var (
topicRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}$`) // Regex must match JS & Android app! topicRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}$`) // Regex must match JS & Android app!
jsonRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}/json$`) jsonRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}(,[-_A-Za-z0-9]{1,64})*/json$`)
sseRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}/sse$`) sseRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}(,[-_A-Za-z0-9]{1,64})*/sse$`)
rawRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}/raw$`) rawRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}(,[-_A-Za-z0-9]{1,64})*/raw$`)
staticRegex = regexp.MustCompile(`^/static/.+`) staticRegex = regexp.MustCompile(`^/static/.+`)
@ -223,7 +223,7 @@ func (s *Server) handleStatic(w http.ResponseWriter, r *http.Request) error {
} }
func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visitor) error { func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visitor) error {
t, err := s.topic(r.URL.Path[1:]) t, err := s.topicFromID(r.URL.Path[1:])
if err != nil { if err != nil {
return err return err
} }
@ -289,7 +289,9 @@ func (s *Server) handleSubscribe(w http.ResponseWriter, r *http.Request, v *visi
return errHTTPTooManyRequests return errHTTPTooManyRequests
} }
defer v.RemoveSubscription() defer v.RemoveSubscription()
t, err := s.topic(strings.TrimSuffix(r.URL.Path[1:], "/"+format)) // Hack topicsStr := strings.TrimSuffix(r.URL.Path[1:], "/"+format) // Hack
topicIDs := strings.Split(topicsStr, ",")
topics, err := s.topicsFromIDs(topicIDs...)
if err != nil { if err != nil {
return err return err
} }
@ -314,14 +316,21 @@ func (s *Server) handleSubscribe(w http.ResponseWriter, r *http.Request, v *visi
w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests
w.Header().Set("Content-Type", contentType+"; charset=utf-8") // Android/Volley client needs charset! w.Header().Set("Content-Type", contentType+"; charset=utf-8") // Android/Volley client needs charset!
if poll { if poll {
return s.sendOldMessages(t, since, sub) return s.sendOldMessages(topics, since, sub)
} }
subscriberID := t.Subscribe(sub) subscriberIDs := make([]int, 0)
defer t.Unsubscribe(subscriberID) for _, t := range topics {
if err := sub(newOpenMessage(t.id)); err != nil { // Send out open message subscriberIDs = append(subscriberIDs, t.Subscribe(sub))
}
defer func() {
for i, subscriberID := range subscriberIDs {
topics[i].Unsubscribe(subscriberID) // Order!
}
}()
if err := sub(newOpenMessage(topicsStr)); err != nil { // Send out open message
return err return err
} }
if err := s.sendOldMessages(t, since, sub); err != nil { if err := s.sendOldMessages(topics, since, sub); err != nil {
return err return err
} }
for { for {
@ -330,17 +339,18 @@ func (s *Server) handleSubscribe(w http.ResponseWriter, r *http.Request, v *visi
return nil return nil
case <-time.After(s.config.KeepaliveInterval): case <-time.After(s.config.KeepaliveInterval):
v.Keepalive() v.Keepalive()
if err := sub(newKeepaliveMessage(t.id)); err != nil { // Send keepalive message if err := sub(newKeepaliveMessage(topicsStr)); err != nil { // Send keepalive message
return err return err
} }
} }
} }
} }
func (s *Server) sendOldMessages(t *topic, since sinceTime, sub subscriber) error { func (s *Server) sendOldMessages(topics []*topic, since sinceTime, sub subscriber) error {
if since.IsNone() { if since.IsNone() {
return nil return nil
} }
for _, t := range topics {
messages, err := s.cache.Messages(t.id, since) messages, err := s.cache.Messages(t.id, since)
if err != nil { if err != nil {
return err return err
@ -350,6 +360,7 @@ func (s *Server) sendOldMessages(t *topic, since sinceTime, sub subscriber) erro
return err return err
} }
} }
}
return nil return nil
} }
@ -382,9 +393,19 @@ func (s *Server) handleOptions(w http.ResponseWriter, r *http.Request) error {
return nil return nil
} }
func (s *Server) topic(id string) (*topic, error) { func (s *Server) topicFromID(id string) (*topic, error) {
topics, err := s.topicsFromIDs(id)
if err != nil {
return nil, err
}
return topics[0], nil
}
func (s *Server) topicsFromIDs(ids... string) ([]*topic, error) {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
topics := make([]*topic, 0)
for _, id := range ids {
if _, ok := s.topics[id]; !ok { if _, ok := s.topics[id]; !ok {
if len(s.topics) >= s.config.GlobalTopicLimit { if len(s.topics) >= s.config.GlobalTopicLimit {
return nil, errHTTPTooManyRequests return nil, errHTTPTooManyRequests
@ -394,7 +415,9 @@ func (s *Server) topic(id string) (*topic, error) {
s.topics[id].Subscribe(s.firebase) s.topics[id].Subscribe(s.firebase)
} }
} }
return s.topics[id], nil topics = append(topics, s.topics[id])
}
return topics, nil
} }
func (s *Server) updateStatsAndExpire() { func (s *Server) updateStatsAndExpire() {