1
0
Fork 0
mirror of https://github.com/binwiederhier/ntfy.git synced 2024-12-15 17:50:55 +00:00

Fix rate limiting behind proxy, make configurable

This commit is contained in:
Philipp Heckel 2021-11-05 13:46:27 -04:00
parent 86a16e3944
commit 0170f673bd
5 changed files with 99 additions and 45 deletions

View file

@ -22,8 +22,13 @@ func New() *cli.App {
altsrc.NewStringFlag(&cli.StringFlag{Name: "firebase-key-file", Aliases: []string{"F"}, EnvVars: []string{"NTFY_FIREBASE_KEY_FILE"}, Usage: "Firebase credentials file; if set additionally publish to FCM topic"}),
altsrc.NewStringFlag(&cli.StringFlag{Name: "cache-file", Aliases: []string{"C"}, EnvVars: []string{"NTFY_CACHE_FILE"}, Usage: "cache file used for message caching"}),
altsrc.NewDurationFlag(&cli.DurationFlag{Name: "cache-duration", Aliases: []string{"b"}, EnvVars: []string{"NTFY_CACHE_DURATION"}, Value: config.DefaultCacheDuration, Usage: "buffer messages for this time to allow `since` requests"}),
altsrc.NewDurationFlag(&cli.DurationFlag{Name: "keepalive-interval", Aliases: []string{"k"}, EnvVars: []string{"NTFY_KEEPALIVE_INTERVAL"}, Value: config.DefaultKeepaliveInterval, Usage: "default interval of keepalive messages"}),
altsrc.NewDurationFlag(&cli.DurationFlag{Name: "manager-interval", Aliases: []string{"m"}, EnvVars: []string{"NTFY_MANAGER_INTERVAL"}, Value: config.DefaultManagerInterval, Usage: "default interval of for message pruning and stats printing"}),
altsrc.NewDurationFlag(&cli.DurationFlag{Name: "keepalive-interval", Aliases: []string{"k"}, EnvVars: []string{"NTFY_KEEPALIVE_INTERVAL"}, Value: config.DefaultKeepaliveInterval, Usage: "interval of keepalive messages"}),
altsrc.NewDurationFlag(&cli.DurationFlag{Name: "manager-interval", Aliases: []string{"m"}, EnvVars: []string{"NTFY_MANAGER_INTERVAL"}, Value: config.DefaultManagerInterval, Usage: "interval of for message pruning and stats printing"}),
altsrc.NewIntFlag(&cli.IntFlag{Name: "global-topic-limit", Aliases: []string{"T"}, EnvVars: []string{"NTFY_GLOBAL_TOPIC_LIMIT"}, Value: config.DefaultGlobalTopicLimit, Usage: "total number of topics allowed"}),
altsrc.NewIntFlag(&cli.IntFlag{Name: "visitor-subscription-limit", Aliases: []string{"V"}, EnvVars: []string{"NTFY_VISITOR_SUBSCRIPTION_LIMIT"}, Value: config.DefaultVisitorSubscriptionLimit, Usage: "number of subscriptions per visitor"}),
altsrc.NewIntFlag(&cli.IntFlag{Name: "visitor-request-limit-burst", Aliases: []string{"B"}, EnvVars: []string{"NTFY_VISITOR_REQUEST_LIMIT_BURST"}, Value: config.DefaultVisitorRequestLimitBurst, Usage: "initial limit of requests per visitor"}),
altsrc.NewDurationFlag(&cli.DurationFlag{Name: "visitor-request-limit-replenish", Aliases: []string{"R"}, EnvVars: []string{"NTFY_VISITOR_REQUEST_LIMIT_REPLENISH"}, Value: config.DefaultVisitorRequestLimitReplenish, Usage: "interval at which burst limit is replenished (one per x)"}),
altsrc.NewBoolFlag(&cli.BoolFlag{Name: "behind-proxy", Aliases: []string{"P"}, EnvVars: []string{"NTFY_BEHIND_PROXY"}, Value: false, Usage: "if set, use X-Forwarded-For header to determine visitor IP address (for rate limiting)"}),
}
return &cli.App{
Name: "ntfy",
@ -50,6 +55,11 @@ func execRun(c *cli.Context) error {
cacheDuration := c.Duration("cache-duration")
keepaliveInterval := c.Duration("keepalive-interval")
managerInterval := c.Duration("manager-interval")
globalTopicLimit := c.Int("global-topic-limit")
visitorSubscriptionLimit := c.Int("visitor-subscription-limit")
visitorRequestLimitBurst := c.Int("visitor-request-limit-burst")
visitorRequestLimitReplenish := c.Duration("visitor-request-limit-replenish")
behindProxy := c.Bool("behind-proxy")
// Check values
if firebaseKeyFile != "" && !util.FileExists(firebaseKeyFile) {
@ -69,6 +79,11 @@ func execRun(c *cli.Context) error {
conf.CacheDuration = cacheDuration
conf.KeepaliveInterval = keepaliveInterval
conf.ManagerInterval = managerInterval
conf.GlobalTopicLimit = globalTopicLimit
conf.VisitorSubscriptionLimit = visitorSubscriptionLimit
conf.VisitorRequestLimitBurst = visitorRequestLimitBurst
conf.VisitorRequestLimitReplenish = visitorRequestLimitReplenish
conf.BehindProxy = behindProxy
s, err := server.New(conf)
if err != nil {
log.Fatalln(err)

View file

@ -2,7 +2,6 @@
package config
import (
"golang.org/x/time/rate"
"time"
)
@ -15,14 +14,14 @@ const (
)
// Defines all the limits
// - request limit: max number of PUT/GET/.. requests (here: 50 requests bucket, replenished at a rate of one per 10 seconds)
// - global topic limit: max number of topics overall
// - subscription limit: max number of subscriptions (active HTTP connections) per per-visitor/IP
var (
defaultGlobalTopicLimit = 5000
defaultVisitorRequestLimit = rate.Every(10 * time.Second)
defaultVisitorRequestLimitBurst = 60
defaultVisitorSubscriptionLimit = 30
// - per visistor request limit: max number of PUT/GET/.. requests (here: 60 requests bucket, replenished at a rate of one per 10 seconds)
// - per visistor subscription limit: max number of subscriptions (active HTTP connections) per per-visitor/IP
const (
DefaultGlobalTopicLimit = 5000
DefaultVisitorRequestLimitBurst = 60
DefaultVisitorRequestLimitReplenish = 10 * time.Second
DefaultVisitorSubscriptionLimit = 30
)
// Config is the main config struct for the application. Use New to instantiate a default config struct.
@ -34,9 +33,10 @@ type Config struct {
KeepaliveInterval time.Duration
ManagerInterval time.Duration
GlobalTopicLimit int
VisitorRequestLimit rate.Limit
VisitorRequestLimitBurst int
VisitorRequestLimitReplenish time.Duration
VisitorSubscriptionLimit int
BehindProxy bool
}
// New instantiates a default new config
@ -48,9 +48,10 @@ func New(listenHTTP string) *Config {
CacheDuration: DefaultCacheDuration,
KeepaliveInterval: DefaultKeepaliveInterval,
ManagerInterval: DefaultManagerInterval,
GlobalTopicLimit: defaultGlobalTopicLimit,
VisitorRequestLimit: defaultVisitorRequestLimit,
VisitorRequestLimitBurst: defaultVisitorRequestLimitBurst,
VisitorSubscriptionLimit: defaultVisitorSubscriptionLimit,
GlobalTopicLimit: DefaultGlobalTopicLimit,
VisitorRequestLimitBurst: DefaultVisitorRequestLimitBurst,
VisitorRequestLimitReplenish: DefaultVisitorRequestLimitReplenish,
VisitorSubscriptionLimit: DefaultVisitorSubscriptionLimit,
BehindProxy: false,
}
}

View file

@ -25,6 +25,30 @@
#
# keepalive-interval: 30s
# Interval in which the manager prunes old messages, deletes topics and prints the stats.
# Interval in which the manager prunes old messages, deletes topics
# and prints the stats.
#
# manager-interval: 1m
# Rate limiting: Total number of topics before the server rejects new topics.
#
# global-topic-limit: 5000
# Rate limiting: Number of subscriptions per visitor (IP address)
#
# visitor-subscription-limit: 30
# Rate limiting: Allowed GET/PUT/POST requests per second, per visitor:
# - visitor-request-limit-burst is the initial bucket of requests each visitor has
# - visitor-request-limit-replenish is the rate at which the bucket is refilled
#
# visitor-request-limit-burst: 60
# visitor-request-limit-replenish: 10s
# If set, the X-Forwarded-For header is used to determine the visitor IP address
# instead of the remote address of the connection.
#
# WARNING: If you are behind a proxy, you must set this, otherwise all visitors are rate limited
# as if they are one.
#
# behind-proxy: false

View file

@ -159,24 +159,22 @@ func (s *Server) handle(w http.ResponseWriter, r *http.Request) {
}
func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request) error {
v := s.visitor(r.RemoteAddr)
if err := v.RequestAllowed(); err != nil {
return err
}
if r.Method == http.MethodGet && r.URL.Path == "/" {
return s.handleHome(w, r)
} else if r.Method == http.MethodHead && r.URL.Path == "/" {
return s.handleEmpty(w, r)
} else if r.Method == http.MethodGet && staticRegex.MatchString(r.URL.Path) {
return s.handleStatic(w, r)
} else if (r.Method == http.MethodPut || r.Method == http.MethodPost) && topicRegex.MatchString(r.URL.Path) {
return s.handlePublish(w, r, v)
} else if r.Method == http.MethodGet && jsonRegex.MatchString(r.URL.Path) {
return s.handleSubscribeJSON(w, r, v)
} else if r.Method == http.MethodGet && sseRegex.MatchString(r.URL.Path) {
return s.handleSubscribeSSE(w, r, v)
} else if r.Method == http.MethodGet && rawRegex.MatchString(r.URL.Path) {
return s.handleSubscribeRaw(w, r, v)
} else if r.Method == http.MethodOptions {
return s.handleOptions(w, r)
} else if (r.Method == http.MethodPut || r.Method == http.MethodPost) && topicRegex.MatchString(r.URL.Path) {
return s.withRateLimit(w, r, s.handlePublish)
} else if r.Method == http.MethodGet && jsonRegex.MatchString(r.URL.Path) {
return s.withRateLimit(w, r, s.handleSubscribeJSON)
} else if r.Method == http.MethodGet && sseRegex.MatchString(r.URL.Path) {
return s.withRateLimit(w, r, s.handleSubscribeSSE)
} else if r.Method == http.MethodGet && rawRegex.MatchString(r.URL.Path) {
return s.withRateLimit(w, r, s.handleSubscribeRaw)
}
return errHTTPNotFound
}
@ -186,6 +184,10 @@ func (s *Server) handleHome(w http.ResponseWriter, r *http.Request) error {
return err
}
func (s *Server) handleEmpty(w http.ResponseWriter, r *http.Request) error {
return nil
}
func (s *Server) handleStatic(w http.ResponseWriter, r *http.Request) error {
http.FileServer(http.FS(webStaticFs)).ServeHTTP(w, r)
return nil
@ -394,15 +396,27 @@ func (s *Server) updateStatsAndExpire() {
s.messages, len(s.topics), subscribers, messages, len(s.visitors))
}
func (s *Server) withRateLimit(w http.ResponseWriter, r *http.Request, handler func(w http.ResponseWriter, r *http.Request, v *visitor) error) error {
v := s.visitor(r)
if err := v.RequestAllowed(); err != nil {
return err
}
return handler(w, r, v)
}
// visitor creates or retrieves a rate.Limiter for the given visitor.
// This function was taken from https://www.alexedwards.net/blog/how-to-rate-limit-http-requests (MIT).
func (s *Server) visitor(remoteAddr string) *visitor {
func (s *Server) visitor(r *http.Request) *visitor {
s.mu.Lock()
defer s.mu.Unlock()
remoteAddr := r.RemoteAddr
ip, _, err := net.SplitHostPort(remoteAddr)
if err != nil {
ip = remoteAddr // This should not happen in real life; only in tests.
}
if s.config.BehindProxy && r.Header.Get("X-Forwarded-For") != "" {
ip = r.Header.Get("X-Forwarded-For")
}
v, exists := s.visitors[ip]
if !exists {
s.visitors[ip] = newVisitor(s.config)

View file

@ -24,7 +24,7 @@ type visitor struct {
func newVisitor(conf *config.Config) *visitor {
return &visitor{
config: conf,
limiter: rate.NewLimiter(conf.VisitorRequestLimit, conf.VisitorRequestLimitBurst),
limiter: rate.NewLimiter(rate.Every(conf.VisitorRequestLimitReplenish), conf.VisitorRequestLimitBurst),
subscriptions: util.NewLimiter(int64(conf.VisitorSubscriptionLimit)),
seen: time.Now(),
}