mirror of
https://github.com/binwiederhier/ntfy.git
synced 2024-12-15 17:50:55 +00:00
SQLite cache
This commit is contained in:
parent
1c7695c1f3
commit
7b810acfb5
5 changed files with 254 additions and 133 deletions
|
@ -1,61 +1,14 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"time"
|
||||
_ "github.com/mattn/go-sqlite3" // SQLite driver
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
createTableQuery = `CREATE TABLE IF NOT EXISTS messages (
|
||||
id VARCHAR(20) PRIMARY KEY,
|
||||
time INT NOT NULL,
|
||||
topic VARCHAR(64) NOT NULL,
|
||||
message VARCHAR(1024) NOT NULL
|
||||
)`
|
||||
insertQuery = `INSERT INTO messages (id, time, topic, message) VALUES (?, ?, ?, ?)`
|
||||
pruneOlderThanQuery = `DELETE FROM messages WHERE time < ?`
|
||||
)
|
||||
|
||||
type cache struct {
|
||||
db *sql.DB
|
||||
insert *sql.Stmt
|
||||
prune *sql.Stmt
|
||||
}
|
||||
|
||||
func newCache(filename string) (*cache, error) {
|
||||
db, err := sql.Open("sqlite3", filename)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if _, err := db.Exec(createTableQuery); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
insert, err := db.Prepare(insertQuery)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
prune, err := db.Prepare(pruneOlderThanQuery)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &cache{
|
||||
db: db,
|
||||
insert: insert,
|
||||
prune: prune,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *cache) Load() (map[string]*topic, error) {
|
||||
|
||||
}
|
||||
|
||||
func (c *cache) Add(m *message) error {
|
||||
_, err := c.insert.Exec(m.ID, m.Time, m.Topic, m.Message)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *cache) Prune(olderThan time.Duration) error {
|
||||
_, err := c.prune.Exec(time.Now().Add(-1 * olderThan).Unix())
|
||||
return err
|
||||
type cache interface {
|
||||
AddMessage(m *message) error
|
||||
Messages(topic string, since time.Time) ([]*message, error)
|
||||
MessageCount(topic string) (int, error)
|
||||
Topics() (map[string]*topic, error)
|
||||
Prune(keep time.Duration) error
|
||||
}
|
||||
|
|
80
server/cache_mem.go
Normal file
80
server/cache_mem.go
Normal file
|
@ -0,0 +1,80 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
_ "github.com/mattn/go-sqlite3" // SQLite driver
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type memCache struct {
|
||||
messages map[string][]*message
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
var _ cache = (*memCache)(nil)
|
||||
|
||||
func newMemCache() *memCache {
|
||||
return &memCache{
|
||||
messages: make(map[string][]*message),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *memCache) AddMessage(m *message) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if _, ok := s.messages[m.Topic]; !ok {
|
||||
s.messages[m.Topic] = make([]*message, 0)
|
||||
}
|
||||
s.messages[m.Topic] = append(s.messages[m.Topic], m)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *memCache) Messages(topic string, since time.Time) ([]*message, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if _, ok := s.messages[topic]; !ok {
|
||||
return make([]*message, 0), nil
|
||||
}
|
||||
messages := make([]*message, 0) // copy!
|
||||
for _, m := range s.messages[topic] {
|
||||
msgTime := time.Unix(m.Time, 0)
|
||||
if msgTime == since || msgTime.After(since) {
|
||||
messages = append(messages, m)
|
||||
}
|
||||
}
|
||||
return messages, nil
|
||||
}
|
||||
|
||||
func (s *memCache) MessageCount(topic string) (int, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if _, ok := s.messages[topic]; !ok {
|
||||
return 0, nil
|
||||
}
|
||||
return len(s.messages[topic]), nil
|
||||
}
|
||||
|
||||
func (s *memCache) Topics() (map[string]*topic, error) {
|
||||
// Hack since we know when this is called there are no messages!
|
||||
return make(map[string]*topic), nil
|
||||
}
|
||||
|
||||
func (s *memCache) Prune(keep time.Duration) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
for topic, _ := range s.messages {
|
||||
s.pruneTopic(topic, keep)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *memCache) pruneTopic(topic string, keep time.Duration) {
|
||||
for i, m := range s.messages[topic] {
|
||||
msgTime := time.Unix(m.Time, 0)
|
||||
if time.Since(msgTime) < keep {
|
||||
s.messages[topic] = s.messages[topic][i:]
|
||||
return
|
||||
}
|
||||
}
|
||||
s.messages[topic] = make([]*message, 0) // all messages expired
|
||||
}
|
127
server/cache_sqlite.go
Normal file
127
server/cache_sqlite.go
Normal file
|
@ -0,0 +1,127 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
_ "github.com/mattn/go-sqlite3" // SQLite driver
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
createTableQuery = `
|
||||
BEGIN;
|
||||
CREATE TABLE IF NOT EXISTS messages (
|
||||
id VARCHAR(20) PRIMARY KEY,
|
||||
time INT NOT NULL,
|
||||
topic VARCHAR(64) NOT NULL,
|
||||
message VARCHAR(1024) NOT NULL
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_topic ON messages (topic);
|
||||
COMMIT;
|
||||
`
|
||||
insertMessageQuery = `INSERT INTO messages (id, time, topic, message) VALUES (?, ?, ?, ?)`
|
||||
pruneMessagesQuery = `DELETE FROM messages WHERE time < ?`
|
||||
selectMessagesSinceTimeQuery = `
|
||||
SELECT id, time, message
|
||||
FROM messages
|
||||
WHERE topic = ? AND time >= ?
|
||||
ORDER BY time ASC
|
||||
`
|
||||
selectMessageCountQuery = `SELECT count(*) FROM messages WHERE topic = ?`
|
||||
selectTopicsQuery = `SELECT topic, MAX(time) FROM messages GROUP BY TOPIC`
|
||||
)
|
||||
|
||||
type sqliteCache struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
var _ cache = (*sqliteCache)(nil)
|
||||
|
||||
func newSqliteCache(filename string) (*sqliteCache, error) {
|
||||
db, err := sql.Open("sqlite3", filename)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if _, err := db.Exec(createTableQuery); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &sqliteCache{
|
||||
db: db,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *sqliteCache) AddMessage(m *message) error {
|
||||
_, err := c.db.Exec(insertMessageQuery, m.ID, m.Time, m.Topic, m.Message)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *sqliteCache) Messages(topic string, since time.Time) ([]*message, error) {
|
||||
rows, err := c.db.Query(selectMessagesSinceTimeQuery, topic, since.Unix())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
messages := make([]*message, 0)
|
||||
for rows.Next() {
|
||||
var timestamp int64
|
||||
var id, msg string
|
||||
if err := rows.Scan(&id, ×tamp, &msg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
messages = append(messages, &message{
|
||||
ID: id,
|
||||
Time: timestamp,
|
||||
Event: messageEvent,
|
||||
Topic: topic,
|
||||
Message: msg,
|
||||
})
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return messages, nil
|
||||
}
|
||||
|
||||
func (c *sqliteCache) MessageCount(topic string) (int, error) {
|
||||
rows, err := c.db.Query(selectMessageCountQuery, topic)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var count int
|
||||
if !rows.Next() {
|
||||
return 0, errors.New("no rows found")
|
||||
}
|
||||
if err := rows.Scan(&count); err != nil {
|
||||
return 0, err
|
||||
} else if err := rows.Err(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
func (s *sqliteCache) Topics() (map[string]*topic, error) {
|
||||
rows, err := s.db.Query(selectTopicsQuery)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
topics := make(map[string]*topic, 0)
|
||||
for rows.Next() {
|
||||
var id string
|
||||
var last int64
|
||||
if err := rows.Scan(&id, &last); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
topics[id] = newTopic(id, time.Unix(last, 0))
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return topics, nil
|
||||
}
|
||||
|
||||
func (c *sqliteCache) Prune(keep time.Duration) error {
|
||||
_, err := c.db.Exec(pruneMessagesQuery, time.Now().Add(-1 * keep).Unix())
|
||||
return err
|
||||
}
|
|
@ -32,7 +32,7 @@ type Server struct {
|
|||
visitors map[string]*visitor
|
||||
firebase subscriber
|
||||
messages int64
|
||||
cache *cache
|
||||
cache cache
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
|
@ -78,30 +78,28 @@ func New(conf *config.Config) (*Server, error) {
|
|||
return nil, err
|
||||
}
|
||||
}
|
||||
cache, err := maybeCreateCache(conf)
|
||||
cache, err := createCache(conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
topics := make(map[string]*topic)
|
||||
if cache != nil {
|
||||
if topics, err = cache.Load(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
topics, err := cache.Topics()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Server{
|
||||
config: conf,
|
||||
cache: cache,
|
||||
cache: cache,
|
||||
firebase: firebaseSubscriber,
|
||||
topics: topics,
|
||||
visitors: make(map[string]*visitor),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func maybeCreateCache(conf *config.Config) (*cache, error) {
|
||||
if conf.CacheFile == "" {
|
||||
return nil, nil
|
||||
func createCache(conf *config.Config) (cache, error) {
|
||||
if conf.CacheFile != "" {
|
||||
return newSqliteCache(conf.CacheFile)
|
||||
}
|
||||
return newCache(conf.CacheFile)
|
||||
return newMemCache(), nil
|
||||
}
|
||||
|
||||
func createFirebaseSubscriber(conf *config.Config) (subscriber, error) {
|
||||
|
@ -202,8 +200,8 @@ func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visito
|
|||
if err := t.Publish(m); err != nil {
|
||||
return err
|
||||
}
|
||||
if s.cache != nil {
|
||||
s.cache.Add(m)
|
||||
if err := s.cache.AddMessage(m); err != nil {
|
||||
return err
|
||||
}
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests
|
||||
s.mu.Lock()
|
||||
|
@ -277,20 +275,18 @@ func (s *Server) handleSubscribe(w http.ResponseWriter, r *http.Request, v *visi
|
|||
w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests
|
||||
w.Header().Set("Content-Type", contentType)
|
||||
if poll {
|
||||
return sendOldMessages(t, since, sub)
|
||||
return s.sendOldMessages(t, since, sub)
|
||||
}
|
||||
subscriberID := t.Subscribe(sub)
|
||||
defer t.Unsubscribe(subscriberID)
|
||||
if err := sub(newOpenMessage(t.id)); err != nil { // Send out open message
|
||||
return err
|
||||
}
|
||||
if err := sendOldMessages(t, since, sub); err != nil {
|
||||
if err := s.sendOldMessages(t, since, sub); err != nil {
|
||||
return err
|
||||
}
|
||||
for {
|
||||
select {
|
||||
case <-t.ctx.Done():
|
||||
return nil
|
||||
case <-r.Context().Done():
|
||||
return nil
|
||||
case <-time.After(s.config.KeepaliveInterval):
|
||||
|
@ -302,11 +298,15 @@ func (s *Server) handleSubscribe(w http.ResponseWriter, r *http.Request, v *visi
|
|||
}
|
||||
}
|
||||
|
||||
func sendOldMessages(t *topic, since time.Time, sub subscriber) error {
|
||||
func (s *Server) sendOldMessages(t *topic, since time.Time, sub subscriber) error {
|
||||
if since.IsZero() {
|
||||
return nil
|
||||
}
|
||||
for _, m := range t.Messages(since) {
|
||||
messages, err := s.cache.Messages(t.id, since)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, m := range messages {
|
||||
if err := sub(m); err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -340,7 +340,7 @@ func (s *Server) topic(id string) (*topic, error) {
|
|||
if len(s.topics) >= s.config.GlobalTopicLimit {
|
||||
return nil, errHTTPTooManyRequests
|
||||
}
|
||||
s.topics[id] = newTopic(id)
|
||||
s.topics[id] = newTopic(id, time.Now())
|
||||
if s.firebase != nil {
|
||||
s.topics[id].Subscribe(s.firebase)
|
||||
}
|
||||
|
@ -360,28 +360,28 @@ func (s *Server) updateStatsAndExpire() {
|
|||
}
|
||||
|
||||
// Prune cache
|
||||
if s.cache != nil {
|
||||
if err := s.cache.Prune(s.config.MessageBufferDuration); err != nil {
|
||||
log.Printf("error pruning cache: %s", err.Error())
|
||||
}
|
||||
if err := s.cache.Prune(s.config.MessageBufferDuration); err != nil {
|
||||
log.Printf("error pruning cache: %s", err.Error())
|
||||
}
|
||||
|
||||
// Prune old messages, remove subscriptions without subscribers
|
||||
for _, t := range s.topics {
|
||||
t.Prune(s.config.MessageBufferDuration)
|
||||
subs, msgs := t.Stats()
|
||||
if msgs == 0 && (subs == 0 || (s.firebase != nil && subs == 1)) { // Firebase is a subscriber!
|
||||
delete(s.topics, t.id)
|
||||
}
|
||||
}
|
||||
|
||||
// Print stats
|
||||
var subscribers, messages int
|
||||
for _, t := range s.topics {
|
||||
subs, msgs := t.Stats()
|
||||
subs := t.Subscribers()
|
||||
msgs, err := s.cache.MessageCount(t.id)
|
||||
if err != nil {
|
||||
log.Printf("cannot get stats for topic %s: %s", t.id, err.Error())
|
||||
continue
|
||||
}
|
||||
if msgs == 0 && (subs == 0 || (s.firebase != nil && subs == 1)) { // Firebase is a subscriber!
|
||||
delete(s.topics, t.id)
|
||||
continue
|
||||
}
|
||||
subscribers += subs
|
||||
messages += msgs
|
||||
}
|
||||
|
||||
// Print stats
|
||||
log.Printf("Stats: %d message(s) published, %d topic(s) active, %d subscriber(s), %d message(s) buffered, %d visitor(s)",
|
||||
s.messages, len(s.topics), subscribers, messages, len(s.visitors))
|
||||
}
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"math/rand"
|
||||
"sync"
|
||||
|
@ -12,11 +11,8 @@ import (
|
|||
// can publish a message
|
||||
type topic struct {
|
||||
id string
|
||||
subscribers map[int]subscriber
|
||||
messages []*message
|
||||
last time.Time
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
subscribers map[int]subscriber
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
|
@ -24,15 +20,11 @@ type topic struct {
|
|||
type subscriber func(msg *message) error
|
||||
|
||||
// newTopic creates a new topic
|
||||
func newTopic(id string) *topic {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
func newTopic(id string, last time.Time) *topic {
|
||||
return &topic{
|
||||
id: id,
|
||||
last: last,
|
||||
subscribers: make(map[int]subscriber),
|
||||
messages: make([]*message, 0),
|
||||
last: time.Now(),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -55,7 +47,6 @@ func (t *topic) Publish(m *message) error {
|
|||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
t.last = time.Now()
|
||||
t.messages = append(t.messages, m)
|
||||
for _, s := range t.subscribers {
|
||||
if err := s(m); err != nil {
|
||||
log.Printf("error publishing message to subscriber")
|
||||
|
@ -64,38 +55,8 @@ func (t *topic) Publish(m *message) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (t *topic) Messages(since time.Time) []*message {
|
||||
func (t *topic) Subscribers() int {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
messages := make([]*message, 0) // copy!
|
||||
for _, m := range t.messages {
|
||||
msgTime := time.Unix(m.Time, 0)
|
||||
if msgTime == since || msgTime.After(since) {
|
||||
messages = append(messages, m)
|
||||
}
|
||||
}
|
||||
return messages
|
||||
}
|
||||
|
||||
func (t *topic) Prune(keep time.Duration) {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
for i, m := range t.messages {
|
||||
msgTime := time.Unix(m.Time, 0)
|
||||
if time.Since(msgTime) < keep {
|
||||
t.messages = t.messages[i:]
|
||||
return
|
||||
}
|
||||
}
|
||||
t.messages = make([]*message, 0)
|
||||
}
|
||||
|
||||
func (t *topic) Stats() (subscribers int, messages int) {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
return len(t.subscribers), len(t.messages)
|
||||
}
|
||||
|
||||
func (t *topic) Close() {
|
||||
t.cancel()
|
||||
return len(t.subscribers)
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue