1
0
Fork 0
mirror of https://github.com/binwiederhier/ntfy.git synced 2024-12-15 17:50:55 +00:00

Auth rate limiter

This commit is contained in:
binwiederhier 2023-02-08 15:20:44 -05:00
parent 3ac315a9e7
commit e1a4a74905
16 changed files with 152 additions and 60 deletions

View file

@ -79,7 +79,9 @@ user * (role: anonymous, tier: none)
func runAccessCommand(app *cli.App, conf *server.Config, args ...string) error { func runAccessCommand(app *cli.App, conf *server.Config, args ...string) error {
userArgs := []string{ userArgs := []string{
"ntfy", "ntfy",
"--log-level=ERROR",
"access", "access",
"--config=" + conf.File, // Dummy config file to avoid lookups of real file
"--auth-file=" + conf.AuthFile, "--auth-file=" + conf.AuthFile,
"--auth-default-access=" + conf.AuthDefault.String(), "--auth-default-access=" + conf.AuthDefault.String(),
} }

View file

@ -253,6 +253,7 @@ func execServe(c *cli.Context) error {
// Run server // Run server
conf := server.NewConfig() conf := server.NewConfig()
conf.File = config
conf.BaseURL = baseURL conf.BaseURL = baseURL
conf.ListenHTTP = listenHTTP conf.ListenHTTP = listenHTTP
conf.ListenHTTPS = listenHTTPS conf.ListenHTTPS = listenHTTPS

View file

@ -38,7 +38,9 @@ func TestCLI_Tier_AddListChangeDelete(t *testing.T) {
func runTierCommand(app *cli.App, conf *server.Config, args ...string) error { func runTierCommand(app *cli.App, conf *server.Config, args ...string) error {
userArgs := []string{ userArgs := []string{
"ntfy", "ntfy",
"--log-level=ERROR",
"tier", "tier",
"--config=" + conf.File, // Dummy config file to avoid lookups of real file
"--auth-file=" + conf.AuthFile, "--auth-file=" + conf.AuthFile,
"--auth-default-access=" + conf.AuthDefault.String(), "--auth-default-access=" + conf.AuthDefault.String(),
} }

View file

@ -41,7 +41,9 @@ func TestCLI_Token_AddListRemove(t *testing.T) {
func runTokenCommand(app *cli.App, conf *server.Config, args ...string) error { func runTokenCommand(app *cli.App, conf *server.Config, args ...string) error {
userArgs := []string{ userArgs := []string{
"ntfy", "ntfy",
"--log-level=ERROR",
"token", "token",
"--config=" + conf.File, // Dummy config file to avoid lookups of real file
"--auth-file=" + conf.AuthFile, "--auth-file=" + conf.AuthFile,
} }
return app.Run(append(userArgs, args...)) return app.Run(append(userArgs, args...))

View file

@ -6,6 +6,7 @@ import (
"heckel.io/ntfy/server" "heckel.io/ntfy/server"
"heckel.io/ntfy/test" "heckel.io/ntfy/test"
"heckel.io/ntfy/user" "heckel.io/ntfy/user"
"os"
"path/filepath" "path/filepath"
"testing" "testing"
) )
@ -113,7 +114,10 @@ func TestCLI_User_Delete(t *testing.T) {
} }
func newTestServerWithAuth(t *testing.T) (s *server.Server, conf *server.Config, port int) { func newTestServerWithAuth(t *testing.T) (s *server.Server, conf *server.Config, port int) {
configFile := filepath.Join(t.TempDir(), "server-dummy.yml")
require.Nil(t, os.WriteFile(configFile, []byte(""), 0600)) // Dummy config file to avoid lookup of real server.yml
conf = server.NewConfig() conf = server.NewConfig()
conf.File = configFile
conf.AuthFile = filepath.Join(t.TempDir(), "user.db") conf.AuthFile = filepath.Join(t.TempDir(), "user.db")
conf.AuthDefault = user.PermissionDenyAll conf.AuthDefault = user.PermissionDenyAll
s, port = test.StartServerWithConfig(t, conf) s, port = test.StartServerWithConfig(t, conf)
@ -123,7 +127,9 @@ func newTestServerWithAuth(t *testing.T) (s *server.Server, conf *server.Config,
func runUserCommand(app *cli.App, conf *server.Config, args ...string) error { func runUserCommand(app *cli.App, conf *server.Config, args ...string) error {
userArgs := []string{ userArgs := []string{
"ntfy", "ntfy",
"--log-level=ERROR",
"user", "user",
"--config=" + conf.File, // Dummy config file to avoid lookups of real file
"--auth-file=" + conf.AuthFile, "--auth-file=" + conf.AuthFile,
"--auth-default-access=" + conf.AuthDefault.String(), "--auth-default-access=" + conf.AuthDefault.String(),
} }

View file

@ -82,8 +82,10 @@ func (e *Event) Time(t time.Time) *Event {
// Err adds an "error" field to the log event // Err adds an "error" field to the log event
func (e *Event) Err(err error) *Event { func (e *Event) Err(err error) *Event {
if c, ok := err.(Contexter); ok { if err == nil {
return e.Fields(c.Context()) return e
} else if c, ok := err.(Contexter); ok {
return e.With(c)
} }
return e.Field(errorField, err.Error()) return e.Field(errorField, err.Error())
} }

View file

@ -49,6 +49,8 @@ const (
DefaultVisitorEmailLimitReplenish = time.Hour DefaultVisitorEmailLimitReplenish = time.Hour
DefaultVisitorAccountCreationLimitBurst = 3 DefaultVisitorAccountCreationLimitBurst = 3
DefaultVisitorAccountCreationLimitReplenish = 24 * time.Hour DefaultVisitorAccountCreationLimitReplenish = 24 * time.Hour
DefaultVisitorAuthFailureLimitBurst = 10
DefaultVisitorAuthFailureLimitReplenish = time.Minute
DefaultVisitorAttachmentTotalSizeLimit = 100 * 1024 * 1024 // 100 MB DefaultVisitorAttachmentTotalSizeLimit = 100 * 1024 * 1024 // 100 MB
DefaultVisitorAttachmentDailyBandwidthLimit = 500 * 1024 * 1024 // 500 MB DefaultVisitorAttachmentDailyBandwidthLimit = 500 * 1024 * 1024 // 500 MB
) )
@ -60,6 +62,7 @@ var (
// Config is the main config struct for the application. Use New to instantiate a default config struct. // Config is the main config struct for the application. Use New to instantiate a default config struct.
type Config struct { type Config struct {
File string // Config file, only used for testing
BaseURL string BaseURL string
ListenHTTP string ListenHTTP string
ListenHTTPS string ListenHTTPS string
@ -113,6 +116,8 @@ type Config struct {
VisitorEmailLimitReplenish time.Duration VisitorEmailLimitReplenish time.Duration
VisitorAccountCreationLimitBurst int VisitorAccountCreationLimitBurst int
VisitorAccountCreationLimitReplenish time.Duration VisitorAccountCreationLimitReplenish time.Duration
VisitorAuthFailureLimitBurst int
VisitorAuthFailureLimitReplenish time.Duration
VisitorStatsResetTime time.Time // Time of the day at which to reset visitor stats VisitorStatsResetTime time.Time // Time of the day at which to reset visitor stats
BehindProxy bool BehindProxy bool
StripeSecretKey string StripeSecretKey string
@ -129,6 +134,7 @@ type Config struct {
// NewConfig instantiates a default new server config // NewConfig instantiates a default new server config
func NewConfig() *Config { func NewConfig() *Config {
return &Config{ return &Config{
File: "", // Only used for testing
BaseURL: "", BaseURL: "",
ListenHTTP: DefaultListenHTTP, ListenHTTP: DefaultListenHTTP,
ListenHTTPS: "", ListenHTTPS: "",
@ -182,6 +188,8 @@ func NewConfig() *Config {
VisitorEmailLimitReplenish: DefaultVisitorEmailLimitReplenish, VisitorEmailLimitReplenish: DefaultVisitorEmailLimitReplenish,
VisitorAccountCreationLimitBurst: DefaultVisitorAccountCreationLimitBurst, VisitorAccountCreationLimitBurst: DefaultVisitorAccountCreationLimitBurst,
VisitorAccountCreationLimitReplenish: DefaultVisitorAccountCreationLimitReplenish, VisitorAccountCreationLimitReplenish: DefaultVisitorAccountCreationLimitReplenish,
VisitorAuthFailureLimitBurst: DefaultVisitorAuthFailureLimitBurst,
VisitorAuthFailureLimitReplenish: DefaultVisitorAuthFailureLimitReplenish,
VisitorStatsResetTime: DefaultVisitorStatsResetTime, VisitorStatsResetTime: DefaultVisitorStatsResetTime,
BehindProxy: false, BehindProxy: false,
StripeSecretKey: "", StripeSecretKey: "",

View file

@ -87,6 +87,7 @@ var (
errHTTPTooManyRequestsLimitAccountCreation = &errHTTP{42906, http.StatusTooManyRequests, "limit reached: too many accounts created", "https://ntfy.sh/docs/publish/#limitations"} // FIXME document limit errHTTPTooManyRequestsLimitAccountCreation = &errHTTP{42906, http.StatusTooManyRequests, "limit reached: too many accounts created", "https://ntfy.sh/docs/publish/#limitations"} // FIXME document limit
errHTTPTooManyRequestsLimitReservations = &errHTTP{42907, http.StatusTooManyRequests, "limit reached: too many topic reservations for this user", ""} errHTTPTooManyRequestsLimitReservations = &errHTTP{42907, http.StatusTooManyRequests, "limit reached: too many topic reservations for this user", ""}
errHTTPTooManyRequestsLimitMessages = &errHTTP{42908, http.StatusTooManyRequests, "limit reached: daily message quota reached", "https://ntfy.sh/docs/publish/#limitations"} errHTTPTooManyRequestsLimitMessages = &errHTTP{42908, http.StatusTooManyRequests, "limit reached: daily message quota reached", "https://ntfy.sh/docs/publish/#limitations"}
errHTTPTooManyRequestsLimitAuthFailure = &errHTTP{42909, http.StatusTooManyRequests, "limit reached: too many auth failures", "https://ntfy.sh/docs/publish/#limitations"} // FIXME document limit
errHTTPInternalError = &errHTTP{50001, http.StatusInternalServerError, "internal server error", ""} errHTTPInternalError = &errHTTP{50001, http.StatusInternalServerError, "internal server error", ""}
errHTTPInternalErrorInvalidPath = &errHTTP{50002, http.StatusInternalServerError, "internal server error: invalid path", ""} errHTTPInternalErrorInvalidPath = &errHTTP{50002, http.StatusInternalServerError, "internal server error: invalid path", ""}
errHTTPInternalErrorMissingBaseURL = &errHTTP{50003, http.StatusInternalServerError, "internal server error: base-url must be be configured for this feature", "https://ntfy.sh/docs/config/"} errHTTPInternalErrorMissingBaseURL = &errHTTP{50003, http.StatusInternalServerError, "internal server error: base-url must be be configured for this feature", "https://ntfy.sh/docs/config/"}

View file

@ -34,9 +34,9 @@ import (
/* /*
- HIGH Rate limiting: Sensitive endpoints (account/login/change-password/...)
- HIGH Account limit creation triggers when account is taken!
- HIGH Docs - HIGH Docs
- tiers
- api
- HIGH Self-review - HIGH Self-review
- MEDIUM: Test for expiring messages after reservation removal - MEDIUM: Test for expiring messages after reservation removal
- MEDIUM: Test new token endpoints & never-expiring token - MEDIUM: Test new token endpoints & never-expiring token
@ -1540,18 +1540,6 @@ func (s *Server) sendDelayedMessage(v *visitor, m *message) error {
return nil return nil
} }
func (s *Server) limitRequests(next handleFunc) handleFunc {
return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
if util.ContainsIP(s.config.VisitorRequestExemptIPAddrs, v.ip) {
return next(w, r, v)
} else if err := v.RequestAllowed(); err != nil {
logvr(v, r).Err(err).Trace("Request not allowed by rate limiter")
return errHTTPTooManyRequestsLimitRequests
}
return next(w, r, v)
}
}
// transformBodyJSON peeks the request body, reads the JSON, and converts it to headers // transformBodyJSON peeks the request body, reads the JSON, and converts it to headers
// before passing it on to the next handler. This is meant to be used in combination with handlePublish. // before passing it on to the next handler. This is meant to be used in combination with handlePublish.
func (s *Server) transformBodyJSON(next handleFunc) handleFunc { func (s *Server) transformBodyJSON(next handleFunc) handleFunc {
@ -1648,43 +1636,65 @@ func (s *Server) autorizeTopic(next handleFunc, perm user.Permission) handleFunc
} }
} }
// maybeAuthenticate creates or retrieves a rate.Limiter for the given visitor. // maybeAuthenticate reads the "Authorization" header and will try to authenticate the user
// Note that this function will always return a visitor, even if an error occurs. // if it is set.
func (s *Server) maybeAuthenticate(r *http.Request) (v *visitor, err error) { //
// - If the header is not set, an IP-based visitor is returned
// - If the header is set, authenticate will be called to check the username/password (Basic auth),
// or the token (Bearer auth), and read the user from the database
//
// This function will ALWAYS return a visitor, even if an error occurs (e.g. unauthorized), so
// that subsequent logging calls still have a visitor context.
func (s *Server) maybeAuthenticate(r *http.Request) (*visitor, error) {
// Read "Authorization" header value, and exit out early if it's not set
ip := extractIPAddress(r, s.config.BehindProxy) ip := extractIPAddress(r, s.config.BehindProxy)
var u *user.User // may stay nil if no auth header! vip := s.visitor(ip, nil)
if u, err = s.authenticate(r); err != nil { header, err := readAuthHeader(r)
logr(r).Err(err).Debug("Authentication failed: %s", err.Error()) if err != nil {
err = errHTTPUnauthorized // Always return visitor, even when error occurs! return vip, err
} else if header == "" {
return vip, nil
} else if s.userManager == nil {
return vip, errHTTPUnauthorized
} }
v = s.visitor(ip, u) // If we're trying to auth, check the rate limiter first
v.SetUser(u) // Update visitor user with latest from database! if !vip.AuthAllowed() {
return v, err // Always return visitor, even when error occurs! return vip, errHTTPTooManyRequestsLimitAuthFailure // Always return visitor, even when error occurs!
}
u, err := s.authenticate(r, header)
if err != nil {
vip.AuthFailed()
logr(r).Err(err).Debug("Authentication failed")
return vip, errHTTPUnauthorized // Always return visitor, even when error occurs!
}
// Authentication with user was successful
return s.visitor(ip, u), nil
} }
// authenticate a user based on basic auth username/password (Authorization: Basic ...), or token auth (Authorization: Bearer ...). // authenticate a user based on basic auth username/password (Authorization: Basic ...), or token auth (Authorization: Bearer ...).
// The Authorization header can be passed as a header or the ?auth=... query param. The latter is required only to // The Authorization header can be passed as a header or the ?auth=... query param. The latter is required only to
// support the WebSocket JavaScript class, which does not support passing headers during the initial request. The auth // support the WebSocket JavaScript class, which does not support passing headers during the initial request. The auth
// query param is effectively double base64 encoded. Its format is base64(Basic base64(user:pass)). // query param is effectively doubly base64 encoded. Its format is base64(Basic base64(user:pass)).
func (s *Server) authenticate(r *http.Request) (user *user.User, err error) { func (s *Server) authenticate(r *http.Request, header string) (user *user.User, err error) {
if strings.HasPrefix(header, "Bearer") {
return s.authenticateBearerAuth(r, strings.TrimSpace(strings.TrimPrefix(header, "Bearer")))
}
return s.authenticateBasicAuth(r, header)
}
// readAuthHeader reads the raw value of the Authorization header, either from the actual HTTP header,
// or from the ?auth... query parameter
func readAuthHeader(r *http.Request) (string, error) {
value := strings.TrimSpace(r.Header.Get("Authorization")) value := strings.TrimSpace(r.Header.Get("Authorization"))
queryParam := readQueryParam(r, "authorization", "auth") queryParam := readQueryParam(r, "authorization", "auth")
if queryParam != "" { if queryParam != "" {
a, err := base64.RawURLEncoding.DecodeString(queryParam) a, err := base64.RawURLEncoding.DecodeString(queryParam)
if err != nil { if err != nil {
return nil, err return "", err
} }
value = strings.TrimSpace(string(a)) value = strings.TrimSpace(string(a))
} }
if value == "" { return value, nil
return nil, nil
} else if s.userManager == nil {
return nil, errHTTPUnauthorized
}
if strings.HasPrefix(value, "Bearer") {
return s.authenticateBearerAuth(r, strings.TrimSpace(strings.TrimPrefix(value, "Bearer")))
}
return s.authenticateBasicAuth(r, value)
} }
func (s *Server) authenticateBasicAuth(r *http.Request, value string) (user *user.User, err error) { func (s *Server) authenticateBasicAuth(r *http.Request, value string) (user *user.User, err error) {
@ -1721,6 +1731,7 @@ func (s *Server) visitor(ip netip.Addr, user *user.User) *visitor {
return s.visitors[id] return s.visitors[id]
} }
v.Keepalive() v.Keepalive()
v.SetUser(user) // Always update with the latest user, may be nil!
return v return v
} }

View file

@ -41,6 +41,7 @@ func (s *Server) handleAccountCreate(w http.ResponseWriter, r *http.Request, v *
if err := s.userManager.AddUser(newAccount.Username, newAccount.Password, user.RoleUser); err != nil { if err := s.userManager.AddUser(newAccount.Username, newAccount.Password, user.RoleUser); err != nil {
return err return err
} }
v.AccountCreated()
return s.writeJSON(w, newSuccessResponse()) return s.writeJSON(w, newSuccessResponse())
} }

View file

@ -39,7 +39,7 @@ func newFirebaseClient(sender firebaseSender, auther user.Auther) *firebaseClien
} }
func (c *firebaseClient) Send(v *visitor, m *message) error { func (c *firebaseClient) Send(v *visitor, m *message) error {
if err := v.FirebaseAllowed(); err != nil { if !v.FirebaseAllowed() {
return errFirebaseTemporarilyBanned return errFirebaseTemporarilyBanned
} }
fbm, err := toFirebaseMessage(m, c.auther) fbm, err := toFirebaseMessage(m, c.auther)

View file

@ -1,9 +1,21 @@
package server package server
import ( import (
"heckel.io/ntfy/util"
"net/http" "net/http"
) )
func (s *Server) limitRequests(next handleFunc) handleFunc {
return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
if util.ContainsIP(s.config.VisitorRequestExemptIPAddrs, v.ip) {
return next(w, r, v)
} else if !v.RequestAllowed() {
return errHTTPTooManyRequestsLimitRequests
}
return next(w, r, v)
}
}
func (s *Server) ensureWebEnabled(next handleFunc) handleFunc { func (s *Server) ensureWebEnabled(next handleFunc) handleFunc {
return func(w http.ResponseWriter, r *http.Request, v *visitor) error { return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
if !s.config.EnableWeb { if !s.config.EnableWeb {

View file

@ -374,13 +374,13 @@ func TestPayments_Checkout_Success_And_Increase_Rate_Limits_Reset_Visitor(t *tes
var wg sync.WaitGroup var wg sync.WaitGroup
for i := 0; i < 209; i++ { for i := 0; i < 209; i++ {
wg.Add(1) wg.Add(1)
go func() { go func(i int) {
defer wg.Done()
rr := request(t, s, "PUT", "/mytopic", "some message", map[string]string{ rr := request(t, s, "PUT", "/mytopic", "some message", map[string]string{
"Authorization": util.BasicAuth("phil", "phil"), "Authorization": util.BasicAuth("phil", "phil"),
}) })
require.Equal(t, 200, rr.Code) require.Equal(t, 200, rr.Code, "Failed on %d", i)
wg.Done() }(i)
}()
} }
wg.Wait() wg.Wait()
rr = request(t, s, "PUT", "/mytopic", "some message", map[string]string{ rr = request(t, s, "PUT", "/mytopic", "some message", map[string]string{

View file

@ -733,6 +733,24 @@ func TestServer_Auth_Fail_CannotPublish(t *testing.T) {
require.Equal(t, 403, response.Code) // Anonymous read not allowed require.Equal(t, 403, response.Code) // Anonymous read not allowed
} }
func TestServer_Auth_Fail_Rate_Limiting(t *testing.T) {
c := newTestConfigWithAuthFile(t)
s := newTestServer(t, c)
for i := 0; i < 10; i++ {
response := request(t, s, "PUT", "/announcements", "test", map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 401, response.Code)
}
response := request(t, s, "PUT", "/announcements", "test", map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 429, response.Code)
require.Equal(t, 42909, toHTTPError(t, response.Body.String()).Code)
}
func TestServer_Auth_ViaQuery(t *testing.T) { func TestServer_Auth_ViaQuery(t *testing.T) {
c := newTestConfigWithAuthFile(t) c := newTestConfigWithAuthFile(t)
c.AuthDefault = user.PermissionDenyAll c.AuthDefault = user.PermissionDenyAll

View file

@ -64,6 +64,7 @@ type visitor struct {
subscriptionLimiter *util.FixedLimiter // Fixed limiter for active subscriptions (ongoing connections) subscriptionLimiter *util.FixedLimiter // Fixed limiter for active subscriptions (ongoing connections)
bandwidthLimiter *util.RateLimiter // Limiter for attachment bandwidth downloads bandwidthLimiter *util.RateLimiter // Limiter for attachment bandwidth downloads
accountLimiter *rate.Limiter // Rate limiter for account creation, may be nil accountLimiter *rate.Limiter // Rate limiter for account creation, may be nil
authLimiter *rate.Limiter // Limiter for incorrect login attempts
firebase time.Time // Next allowed Firebase message firebase time.Time // Next allowed Firebase message
seen time.Time // Last seen time of this visitor (needed for removal of stale visitors) seen time.Time // Last seen time of this visitor (needed for removal of stale visitors)
mu sync.Mutex mu sync.Mutex
@ -130,6 +131,7 @@ func newVisitor(conf *Config, messageCache *messageCache, userManager *user.Mana
emailsLimiter: nil, // Set in resetLimiters emailsLimiter: nil, // Set in resetLimiters
bandwidthLimiter: nil, // Set in resetLimiters bandwidthLimiter: nil, // Set in resetLimiters
accountLimiter: nil, // Set in resetLimiters, may be nil accountLimiter: nil, // Set in resetLimiters, may be nil
authLimiter: nil, // Set in resetLimiters, may be nil
} }
v.resetLimitersNoLock(messages, emails, false) v.resetLimitersNoLock(messages, emails, false)
return v return v
@ -154,6 +156,10 @@ func (v *visitor) contextNoLock() log.Context {
"visitor_request_limiter_limit": v.requestLimiter.Limit(), "visitor_request_limiter_limit": v.requestLimiter.Limit(),
"visitor_request_limiter_tokens": v.requestLimiter.Tokens(), "visitor_request_limiter_tokens": v.requestLimiter.Tokens(),
} }
if v.authLimiter != nil {
fields["visitor_auth_limiter_limit"] = v.authLimiter.Limit()
fields["visitor_auth_limiter_tokens"] = v.authLimiter.Tokens()
}
if v.user != nil { if v.user != nil {
fields["user_id"] = v.user.ID fields["user_id"] = v.user.ID
fields["user_name"] = v.user.Name fields["user_name"] = v.user.Name
@ -182,28 +188,16 @@ func visitorExtendedInfoContext(info *visitorInfo) log.Context {
} }
} }
func (v *visitor) RequestAllowed() error { func (v *visitor) RequestAllowed() bool {
v.mu.Lock() // limiters could be replaced! v.mu.Lock() // limiters could be replaced!
defer v.mu.Unlock() defer v.mu.Unlock()
if !v.requestLimiter.Allow() { return v.requestLimiter.Allow()
return errVisitorLimitReached
}
return nil
} }
func (v *visitor) RequestLimiter() *rate.Limiter { func (v *visitor) FirebaseAllowed() bool {
v.mu.Lock() // limiters could be replaced!
defer v.mu.Unlock()
return v.requestLimiter
}
func (v *visitor) FirebaseAllowed() error {
v.mu.Lock() v.mu.Lock()
defer v.mu.Unlock() defer v.mu.Unlock()
if time.Now().Before(v.firebase) { return !time.Now().Before(v.firebase)
return errVisitorLimitReached
}
return nil
} }
func (v *visitor) FirebaseTemporarilyDeny() { func (v *visitor) FirebaseTemporarilyDeny() {
@ -230,15 +224,44 @@ func (v *visitor) SubscriptionAllowed() bool {
return v.subscriptionLimiter.Allow() return v.subscriptionLimiter.Allow()
} }
// AuthAllowed returns true if an auth request can be attempted (> 1 token available)
func (v *visitor) AuthAllowed() bool {
v.mu.Lock() // limiters could be replaced!
defer v.mu.Unlock()
if v.authLimiter == nil {
return true
}
return v.authLimiter.Tokens() > 1
}
// AuthFailed records an auth failure
func (v *visitor) AuthFailed() {
v.mu.Lock() // limiters could be replaced!
defer v.mu.Unlock()
if v.authLimiter != nil {
v.authLimiter.Allow()
}
}
// AccountCreationAllowed returns true if a new account can be created
func (v *visitor) AccountCreationAllowed() bool { func (v *visitor) AccountCreationAllowed() bool {
v.mu.Lock() // limiters could be replaced! v.mu.Lock() // limiters could be replaced!
defer v.mu.Unlock() defer v.mu.Unlock()
if v.accountLimiter == nil || (v.accountLimiter != nil && !v.accountLimiter.Allow()) { if v.accountLimiter == nil || (v.accountLimiter != nil && v.accountLimiter.Tokens() < 1) {
return false return false
} }
return true return true
} }
// AccountCreated decreases the account limiter. This is to be called after an account was created.
func (v *visitor) AccountCreated() {
v.mu.Lock() // limiters could be replaced!
defer v.mu.Unlock()
if v.accountLimiter != nil {
v.accountLimiter.Allow()
}
}
func (v *visitor) BandwidthAllowed(bytes int64) bool { func (v *visitor) BandwidthAllowed(bytes int64) bool {
v.mu.Lock() // limiters could be replaced! v.mu.Lock() // limiters could be replaced!
defer v.mu.Unlock() defer v.mu.Unlock()
@ -336,8 +359,10 @@ func (v *visitor) resetLimitersNoLock(messages, emails int64, enqueueUpdate bool
v.bandwidthLimiter = util.NewBytesLimiter(int(limits.AttachmentBandwidthLimit), oneDay) v.bandwidthLimiter = util.NewBytesLimiter(int(limits.AttachmentBandwidthLimit), oneDay)
if v.user == nil { if v.user == nil {
v.accountLimiter = rate.NewLimiter(rate.Every(v.config.VisitorAccountCreationLimitReplenish), v.config.VisitorAccountCreationLimitBurst) v.accountLimiter = rate.NewLimiter(rate.Every(v.config.VisitorAccountCreationLimitReplenish), v.config.VisitorAccountCreationLimitBurst)
v.authLimiter = rate.NewLimiter(rate.Every(v.config.VisitorAuthFailureLimitReplenish), v.config.VisitorAuthFailureLimitBurst)
} else { } else {
v.accountLimiter = nil // Users cannot create accounts when logged in v.accountLimiter = nil // Users cannot create accounts when logged in
v.authLimiter = nil // Users are already logged in, no need to limit requests
} }
if enqueueUpdate && v.user != nil { if enqueueUpdate && v.user != nil {
go v.userManager.EnqueueStats(v.user.ID, &user.Stats{ go v.userManager.EnqueueStats(v.user.ID, &user.Stats{

View file

@ -372,6 +372,7 @@ func (a *Manager) AuthenticateToken(token string) (*User, error) {
} }
user, err := a.userByToken(token) user, err := a.userByToken(token)
if err != nil { if err != nil {
log.Tag(tagManager).Field("token", token).Err(err).Trace("Authentication of token failed")
return nil, ErrUnauthenticated return nil, ErrUnauthenticated
} }
user.Token = token user.Token = token