mirror of
https://github.com/binwiederhier/ntfy.git
synced 2024-12-14 11:47:33 +00:00
Delayed deletion
This commit is contained in:
parent
9c082a8331
commit
954d919361
14 changed files with 280 additions and 131 deletions
|
@ -16,8 +16,7 @@ import (
|
|||
)
|
||||
|
||||
const (
|
||||
tierReset = "-"
|
||||
createdByCLI = "cli"
|
||||
tierReset = "-"
|
||||
)
|
||||
|
||||
func init() {
|
||||
|
@ -197,7 +196,7 @@ func execUserAdd(c *cli.Context) error {
|
|||
|
||||
password = p
|
||||
}
|
||||
if err := manager.AddUser(username, password, role, createdByCLI); err != nil {
|
||||
if err := manager.AddUser(username, password, role); err != nil {
|
||||
return err
|
||||
}
|
||||
fmt.Fprintf(c.App.ErrWriter, "user %s added with role %s\n", username, role)
|
||||
|
|
|
@ -39,12 +39,10 @@ TODO
|
|||
--
|
||||
|
||||
- Reservation: Kill existing subscribers when topic is reserved (deadcade)
|
||||
- Rate limiting: Sensitive endpoints (account/login/change-password/...)
|
||||
- Stripe: Add metadata to customer
|
||||
- Reservation (UI): Show "This topic is reserved" error message when trying to reserve a reserved topic (Thorben)
|
||||
- Reservation (UI): Ask for confirmation when removing reservation (deadcade)
|
||||
- Logging: Add detailed logging with username/customerID for all Stripe events (phil)
|
||||
- Rate limiting: Sensitive endpoints (account/login/change-password/...)
|
||||
- Stripe webhook: Do not respond wih error if user does not exist (after account deletion)
|
||||
- Stripe: Add metadata to customer
|
||||
|
||||
races:
|
||||
- v.user --> see publishSyncEventAsync() test
|
||||
|
@ -53,7 +51,7 @@ payments:
|
|||
- reconciliation
|
||||
|
||||
delete messages + reserved topics on ResetTier delete attachments in access.go
|
||||
account deletion should delete messages and reservations and attachments
|
||||
|
||||
|
||||
Limits & rate limiting:
|
||||
rate limiting weirdness. wth is going on?
|
||||
|
@ -1256,11 +1254,14 @@ func (s *Server) execManager() {
|
|||
s.mu.Unlock()
|
||||
log.Debug("Manager: Deleted %d stale visitor(s)", staleVisitors)
|
||||
|
||||
// Delete expired user tokens
|
||||
// Delete expired user tokens and users
|
||||
if s.userManager != nil {
|
||||
if err := s.userManager.RemoveExpiredTokens(); err != nil {
|
||||
log.Warn("Error expiring user tokens: %s", err.Error())
|
||||
}
|
||||
if err := s.userManager.RemoveDeletedUsers(); err != nil {
|
||||
log.Warn("Error deleting soft-deleted users: %s", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// Delete expired attachments
|
||||
|
@ -1283,7 +1284,7 @@ func (s *Server) execManager() {
|
|||
}
|
||||
}
|
||||
|
||||
// DeleteMessages message cache
|
||||
// Prune messages
|
||||
log.Debug("Manager: Pruning messages")
|
||||
expiredMessageIDs, err := s.messageCache.MessagesExpired()
|
||||
if err != nil {
|
||||
|
|
|
@ -11,7 +11,6 @@ import (
|
|||
|
||||
const (
|
||||
subscriptionIDLength = 16
|
||||
createdByAPI = "api"
|
||||
syncTopicAccountSyncEvent = "sync"
|
||||
)
|
||||
|
||||
|
@ -34,7 +33,7 @@ func (s *Server) handleAccountCreate(w http.ResponseWriter, r *http.Request, v *
|
|||
if v.accountLimiter != nil && !v.accountLimiter.Allow() {
|
||||
return errHTTPTooManyRequestsLimitAccountCreation
|
||||
}
|
||||
if err := s.userManager.AddUser(newAccount.Username, newAccount.Password, user.RoleUser, createdByAPI); err != nil { // TODO this should return a User
|
||||
if err := s.userManager.AddUser(newAccount.Username, newAccount.Password, user.RoleUser); err != nil { // TODO this should return a User
|
||||
return err
|
||||
}
|
||||
return s.writeJSON(w, newSuccessResponse())
|
||||
|
@ -118,18 +117,20 @@ func (s *Server) handleAccountGet(w http.ResponseWriter, _ *http.Request, v *vis
|
|||
return s.writeJSON(w, response)
|
||||
}
|
||||
|
||||
func (s *Server) handleAccountDelete(w http.ResponseWriter, _ *http.Request, v *visitor) error {
|
||||
func (s *Server) handleAccountDelete(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
||||
if v.user.Billing.StripeSubscriptionID != "" {
|
||||
log.Info("Deleting user %s (billing customer: %s, billing subscription: %s)", v.user.Name, v.user.Billing.StripeCustomerID, v.user.Billing.StripeSubscriptionID)
|
||||
log.Info("%s Canceling billing subscription %s", logHTTPPrefix(v, r), v.user.Billing.StripeSubscriptionID)
|
||||
if v.user.Billing.StripeSubscriptionID != "" {
|
||||
if _, err := s.stripe.CancelSubscription(v.user.Billing.StripeSubscriptionID); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
} else {
|
||||
log.Info("Deleting user %s", v.user.Name)
|
||||
if err := s.maybeRemoveExcessReservations(logHTTPPrefix(v, r), v.user, 0); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if err := s.userManager.RemoveUser(v.user.Name); err != nil {
|
||||
log.Info("%s Marking user %s as deleted", logHTTPPrefix(v, r), v.user.Name)
|
||||
if err := s.userManager.MarkUserRemoved(v.user); err != nil {
|
||||
return err
|
||||
}
|
||||
return s.writeJSON(w, newSuccessResponse())
|
||||
|
|
|
@ -67,8 +67,8 @@ func TestAccount_Signup_AsUser(t *testing.T) {
|
|||
conf.EnableSignup = true
|
||||
s := newTestServer(t, conf)
|
||||
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, "unit-test"))
|
||||
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser, "unit-test"))
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin))
|
||||
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser))
|
||||
|
||||
rr := request(t, s, "POST", "/v1/account", `{"username":"emma", "password":"emma"}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
|
@ -133,7 +133,7 @@ func TestAccount_Get_Anonymous(t *testing.T) {
|
|||
|
||||
func TestAccount_ChangeSettings(t *testing.T) {
|
||||
s := newTestServer(t, newTestConfigWithAuthFile(t))
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test"))
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
|
||||
user, _ := s.userManager.User("phil")
|
||||
token, _ := s.userManager.CreateToken(user)
|
||||
|
||||
|
@ -160,7 +160,7 @@ func TestAccount_ChangeSettings(t *testing.T) {
|
|||
|
||||
func TestAccount_Subscription_AddUpdateDelete(t *testing.T) {
|
||||
s := newTestServer(t, newTestConfigWithAuthFile(t))
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test"))
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
|
||||
|
||||
rr := request(t, s, "POST", "/v1/account/subscription", `{"base_url": "http://abc.com", "topic": "def"}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
|
@ -210,7 +210,7 @@ func TestAccount_Subscription_AddUpdateDelete(t *testing.T) {
|
|||
|
||||
func TestAccount_ChangePassword(t *testing.T) {
|
||||
s := newTestServer(t, newTestConfigWithAuthFile(t))
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test"))
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
|
||||
|
||||
rr := request(t, s, "POST", "/v1/account/password", `{"password": "phil", "new_password": "new password"}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
|
@ -237,7 +237,7 @@ func TestAccount_ChangePassword_NoAccount(t *testing.T) {
|
|||
|
||||
func TestAccount_ExtendToken(t *testing.T) {
|
||||
s := newTestServer(t, newTestConfigWithAuthFile(t))
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test"))
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
|
||||
|
||||
rr := request(t, s, "POST", "/v1/account/token", "", map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
|
@ -260,7 +260,7 @@ func TestAccount_ExtendToken(t *testing.T) {
|
|||
|
||||
func TestAccount_ExtendToken_NoTokenProvided(t *testing.T) {
|
||||
s := newTestServer(t, newTestConfigWithAuthFile(t))
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test"))
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
|
||||
|
||||
rr := request(t, s, "PATCH", "/v1/account/token", "", map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"), // Not Bearer!
|
||||
|
@ -271,7 +271,7 @@ func TestAccount_ExtendToken_NoTokenProvided(t *testing.T) {
|
|||
|
||||
func TestAccount_DeleteToken(t *testing.T) {
|
||||
s := newTestServer(t, newTestConfigWithAuthFile(t))
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test"))
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
|
||||
|
||||
rr := request(t, s, "POST", "/v1/account/token", "", map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
|
@ -324,10 +324,15 @@ func TestAccount_Delete_Success(t *testing.T) {
|
|||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
|
||||
// Account was marked deleted
|
||||
rr = request(t, s, "GET", "/v1/account", "", map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "mypass"),
|
||||
})
|
||||
require.Equal(t, 401, rr.Code)
|
||||
|
||||
// Cannot re-create account, since still exists
|
||||
rr = request(t, s, "POST", "/v1/account", `{"username":"phil", "password":"mypass"}`, nil)
|
||||
require.Equal(t, 409, rr.Code)
|
||||
}
|
||||
|
||||
func TestAccount_Delete_Not_Allowed(t *testing.T) {
|
||||
|
@ -360,7 +365,7 @@ func TestAccount_Reservation_AddAdminSuccess(t *testing.T) {
|
|||
conf := newTestConfigWithAuthFile(t)
|
||||
conf.EnableSignup = true
|
||||
s := newTestServer(t, conf)
|
||||
require.Nil(t, s.userManager.AddUser("phil", "adminpass", user.RoleAdmin, "unit-test"))
|
||||
require.Nil(t, s.userManager.AddUser("phil", "adminpass", user.RoleAdmin))
|
||||
|
||||
rr := request(t, s, "POST", "/v1/account/reservation", `{"topic":"mytopic","everyone":"deny-all"}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "adminpass"),
|
||||
|
|
|
@ -18,15 +18,10 @@ import (
|
|||
"io"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
errNotAPaidTier = errors.New("tier does not have billing price identifier")
|
||||
errMultipleBillingSubscriptions = errors.New("cannot have multiple billing subscriptions")
|
||||
errNoBillingSubscription = errors.New("user does not have an active billing subscription")
|
||||
)
|
||||
|
||||
// Payments in ntfy are done via Stripe.
|
||||
//
|
||||
// Pretty much all payments related things are in this file. The following processes
|
||||
|
@ -49,6 +44,16 @@ var (
|
|||
// This is used to keep the local user database fields up to date. Stripe is the source of truth.
|
||||
// What Stripe says is mirrored and not questioned.
|
||||
|
||||
var (
|
||||
errNotAPaidTier = errors.New("tier does not have billing price identifier")
|
||||
errMultipleBillingSubscriptions = errors.New("cannot have multiple billing subscriptions")
|
||||
errNoBillingSubscription = errors.New("user does not have an active billing subscription")
|
||||
)
|
||||
|
||||
var (
|
||||
retryUserDelays = []time.Duration{3 * time.Second, 5 * time.Second, 7 * time.Second}
|
||||
)
|
||||
|
||||
// handleBillingTiersGet returns all available paid tiers, and the free tier. This is to populate the upgrade dialog
|
||||
// in the UI. Note that this endpoint does NOT have a user context (no v.user!).
|
||||
func (s *Server) handleBillingTiersGet(w http.ResponseWriter, _ *http.Request, _ *visitor) error {
|
||||
|
@ -114,7 +119,7 @@ func (s *Server) handleAccountBillingSubscriptionCreate(w http.ResponseWriter, r
|
|||
} else if tier.StripePriceID == "" {
|
||||
return errNotAPaidTier
|
||||
}
|
||||
log.Info("Stripe: No existing subscription, creating checkout flow")
|
||||
log.Info("%s Creating Stripe checkout flow", logHTTPPrefix(v, r))
|
||||
var stripeCustomerID *string
|
||||
if v.user.Billing.StripeCustomerID != "" {
|
||||
stripeCustomerID = &v.user.Billing.StripeCustomerID
|
||||
|
@ -138,9 +143,6 @@ func (s *Server) handleAccountBillingSubscriptionCreate(w http.ResponseWriter, r
|
|||
Quantity: stripe.Int64(1),
|
||||
},
|
||||
},
|
||||
/*AutomaticTax: &stripe.CheckoutSessionAutomaticTaxParams{
|
||||
Enabled: stripe.Bool(true),
|
||||
},*/
|
||||
}
|
||||
sess, err := s.stripe.NewCheckoutSession(params)
|
||||
if err != nil {
|
||||
|
@ -155,7 +157,7 @@ func (s *Server) handleAccountBillingSubscriptionCreate(w http.ResponseWriter, r
|
|||
// handleAccountBillingSubscriptionCreateSuccess is called after the Stripe checkout session has succeeded. We use
|
||||
// the session ID in the URL to retrieve the Stripe subscription and update the local database. This is the first
|
||||
// and only time we can map the local username with the Stripe customer ID.
|
||||
func (s *Server) handleAccountBillingSubscriptionCreateSuccess(w http.ResponseWriter, r *http.Request, _ *visitor) error {
|
||||
func (s *Server) handleAccountBillingSubscriptionCreateSuccess(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
||||
// We don't have a v.user in this endpoint, only a userManager!
|
||||
matches := apiAccountBillingSubscriptionCheckoutSuccessRegex.FindStringSubmatch(r.URL.Path)
|
||||
if len(matches) != 2 {
|
||||
|
@ -182,7 +184,8 @@ func (s *Server) handleAccountBillingSubscriptionCreateSuccess(w http.ResponseWr
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := s.updateSubscriptionAndTier(u, tier, sess.Customer.ID, sub.ID, string(sub.Status), sub.CurrentPeriodEnd, sub.CancelAt); err != nil {
|
||||
v.SetUser(u)
|
||||
if err := s.updateSubscriptionAndTier(logHTTPPrefix(v, r), u, tier, sess.Customer.ID, sub.ID, string(sub.Status), sub.CurrentPeriodEnd, sub.CancelAt); err != nil {
|
||||
return err
|
||||
}
|
||||
http.Redirect(w, r, s.config.BaseURL+accountPath, http.StatusSeeOther)
|
||||
|
@ -203,7 +206,7 @@ func (s *Server) handleAccountBillingSubscriptionUpdate(w http.ResponseWriter, r
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
log.Info("Stripe: Changing tier and subscription to %s", tier.Code)
|
||||
log.Info("%s Changing billing tier to %s (price %s) for subscription %s", logHTTPPrefix(v, r), tier.Code, tier.StripePriceID, v.user.Billing.StripeSubscriptionID)
|
||||
sub, err := s.stripe.GetSubscription(v.user.Billing.StripeSubscriptionID)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -228,6 +231,7 @@ func (s *Server) handleAccountBillingSubscriptionUpdate(w http.ResponseWriter, r
|
|||
// handleAccountBillingSubscriptionDelete facilitates downgrading a paid user to a tier-less user,
|
||||
// and cancelling the Stripe subscription entirely
|
||||
func (s *Server) handleAccountBillingSubscriptionDelete(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
||||
log.Info("%s Deleting billing subscription %s", logHTTPPrefix(v, r), v.user.Billing.StripeSubscriptionID)
|
||||
if v.user.Billing.StripeSubscriptionID != "" {
|
||||
params := &stripe.SubscriptionParams{
|
||||
CancelAtPeriodEnd: stripe.Bool(true),
|
||||
|
@ -246,6 +250,7 @@ func (s *Server) handleAccountBillingPortalSessionCreate(w http.ResponseWriter,
|
|||
if v.user.Billing.StripeCustomerID == "" {
|
||||
return errHTTPBadRequestNotAPaidUser
|
||||
}
|
||||
log.Info("%s Creating billing portal session", logHTTPPrefix(v, r))
|
||||
params := &stripe.BillingPortalSessionParams{
|
||||
Customer: stripe.String(v.user.Billing.StripeCustomerID),
|
||||
ReturnURL: stripe.String(s.config.BaseURL),
|
||||
|
@ -280,28 +285,30 @@ func (s *Server) handleAccountBillingWebhook(w http.ResponseWriter, r *http.Requ
|
|||
} else if event.Data == nil || event.Data.Raw == nil {
|
||||
return errHTTPBadRequestBillingRequestInvalid
|
||||
}
|
||||
|
||||
log.Info("Stripe: webhook event %s received", event.Type)
|
||||
switch event.Type {
|
||||
case "customer.subscription.updated":
|
||||
return s.handleAccountBillingWebhookSubscriptionUpdated(event.Data.Raw)
|
||||
case "customer.subscription.deleted":
|
||||
return s.handleAccountBillingWebhookSubscriptionDeleted(event.Data.Raw)
|
||||
default:
|
||||
log.Warn("STRIPE Unhandled webhook event %s received", event.Type)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handleAccountBillingWebhookSubscriptionUpdated(event json.RawMessage) error {
|
||||
r, err := util.UnmarshalJSON[apiStripeSubscriptionUpdatedEvent](io.NopCloser(bytes.NewReader(event)))
|
||||
ev, err := util.UnmarshalJSON[apiStripeSubscriptionUpdatedEvent](io.NopCloser(bytes.NewReader(event)))
|
||||
if err != nil {
|
||||
return err
|
||||
} else if r.ID == "" || r.Customer == "" || r.Status == "" || r.CurrentPeriodEnd == 0 || r.Items == nil || len(r.Items.Data) != 1 || r.Items.Data[0].Price == nil || r.Items.Data[0].Price.ID == "" {
|
||||
} else if ev.ID == "" || ev.Customer == "" || ev.Status == "" || ev.CurrentPeriodEnd == 0 || ev.Items == nil || len(ev.Items.Data) != 1 || ev.Items.Data[0].Price == nil || ev.Items.Data[0].Price.ID == "" {
|
||||
return errHTTPBadRequestBillingRequestInvalid
|
||||
}
|
||||
subscriptionID, priceID := r.ID, r.Items.Data[0].Price.ID
|
||||
log.Info("Stripe: customer %s: Updating subscription to status %s, with price %s", r.Customer, r.Status, priceID)
|
||||
u, err := s.userManager.UserByStripeCustomer(r.Customer)
|
||||
subscriptionID, priceID := ev.ID, ev.Items.Data[0].Price.ID
|
||||
log.Info("%s Updating subscription to status %s, with price %s", logStripePrefix(ev.Customer, ev.ID), ev.Status, priceID)
|
||||
userFn := func() (*user.User, error) {
|
||||
return s.userManager.UserByStripeCustomer(ev.Customer)
|
||||
}
|
||||
u, err := util.Retry[user.User](userFn, retryUserDelays...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -309,7 +316,7 @@ func (s *Server) handleAccountBillingWebhookSubscriptionUpdated(event json.RawMe
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := s.updateSubscriptionAndTier(u, tier, r.Customer, subscriptionID, r.Status, r.CurrentPeriodEnd, r.CancelAt); err != nil {
|
||||
if err := s.updateSubscriptionAndTier(logStripePrefix(ev.Customer, ev.ID), u, tier, ev.Customer, subscriptionID, ev.Status, ev.CurrentPeriodEnd, ev.CancelAt); err != nil {
|
||||
return err
|
||||
}
|
||||
s.publishSyncEventAsync(s.visitorFromUser(u, netip.IPv4Unspecified()))
|
||||
|
@ -317,47 +324,56 @@ func (s *Server) handleAccountBillingWebhookSubscriptionUpdated(event json.RawMe
|
|||
}
|
||||
|
||||
func (s *Server) handleAccountBillingWebhookSubscriptionDeleted(event json.RawMessage) error {
|
||||
r, err := util.UnmarshalJSON[apiStripeSubscriptionDeletedEvent](io.NopCloser(bytes.NewReader(event)))
|
||||
ev, err := util.UnmarshalJSON[apiStripeSubscriptionDeletedEvent](io.NopCloser(bytes.NewReader(event)))
|
||||
if err != nil {
|
||||
return err
|
||||
} else if r.Customer == "" {
|
||||
} else if ev.Customer == "" {
|
||||
return errHTTPBadRequestBillingRequestInvalid
|
||||
}
|
||||
log.Info("Stripe: customer %s: subscription deleted, downgrading to unpaid tier", r.Customer)
|
||||
u, err := s.userManager.UserByStripeCustomer(r.Customer)
|
||||
log.Info("%s Subscription deleted, downgrading to unpaid tier", logStripePrefix(ev.Customer, ev.ID))
|
||||
u, err := s.userManager.UserByStripeCustomer(ev.Customer)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := s.updateSubscriptionAndTier(u, nil, r.Customer, "", "", 0, 0); err != nil {
|
||||
if err := s.updateSubscriptionAndTier(logStripePrefix(ev.Customer, ev.ID), u, nil, ev.Customer, "", "", 0, 0); err != nil {
|
||||
return err
|
||||
}
|
||||
s.publishSyncEventAsync(s.visitorFromUser(u, netip.IPv4Unspecified()))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) updateSubscriptionAndTier(u *user.User, tier *user.Tier, customerID, subscriptionID, status string, paidUntil, cancelAt int64) error {
|
||||
// Remove excess reservations (if too many for tier), and mark associated messages deleted
|
||||
// maybeRemoveExcessReservations deletes topic reservations for the given user (if too many for tier),
|
||||
// and marks associated messages for the topics as deleted. This also eventually deletes attachments.
|
||||
// The process relies on the manager to perform the actual deletions (see runManager).
|
||||
func (s *Server) maybeRemoveExcessReservations(logPrefix string, u *user.User, reservationsLimit int64) error {
|
||||
reservations, err := s.userManager.Reservations(u.Name)
|
||||
if err != nil {
|
||||
return err
|
||||
} else if int64(len(reservations)) <= reservationsLimit {
|
||||
return nil
|
||||
}
|
||||
topics := make([]string, 0)
|
||||
for i := int64(len(reservations)) - 1; i >= reservationsLimit; i-- {
|
||||
topics = append(topics, reservations[i].Topic)
|
||||
}
|
||||
log.Info("%s Removing excess reservations for topics %s", logPrefix, strings.Join(topics, ", "))
|
||||
if err := s.userManager.RemoveReservations(u.Name, topics...); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := s.messageCache.ExpireMessages(topics...); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) updateSubscriptionAndTier(logPrefix string, u *user.User, tier *user.Tier, customerID, subscriptionID, status string, paidUntil, cancelAt int64) error {
|
||||
reservationsLimit := visitorDefaultReservationsLimit
|
||||
if tier != nil {
|
||||
reservationsLimit = tier.ReservationsLimit
|
||||
}
|
||||
if int64(len(reservations)) > reservationsLimit {
|
||||
topics := make([]string, 0)
|
||||
for i := int64(len(reservations)) - 1; i >= reservationsLimit; i-- {
|
||||
topics = append(topics, reservations[i].Topic)
|
||||
}
|
||||
if err := s.userManager.RemoveReservations(u.Name, topics...); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := s.messageCache.ExpireMessages(topics...); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := s.maybeRemoveExcessReservations(logPrefix, u, reservationsLimit); err != nil {
|
||||
return err
|
||||
}
|
||||
// Change or remove tier
|
||||
if tier == nil {
|
||||
if err := s.userManager.ResetTier(u.Name); err != nil {
|
||||
return err
|
||||
|
|
|
@ -34,7 +34,7 @@ func TestPayments_SubscriptionCreate_NotAStripeCustomer_Success(t *testing.T) {
|
|||
Code: "pro",
|
||||
StripePriceID: "price_123",
|
||||
}))
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test"))
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
|
||||
|
||||
// Create subscription
|
||||
response := request(t, s, "POST", "/v1/account/billing/subscription", `{"tier": "pro"}`, map[string]string{
|
||||
|
@ -69,7 +69,7 @@ func TestPayments_SubscriptionCreate_StripeCustomer_Success(t *testing.T) {
|
|||
Code: "pro",
|
||||
StripePriceID: "price_123",
|
||||
}))
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test"))
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
|
||||
|
||||
u, err := s.userManager.User("phil")
|
||||
require.Nil(t, err)
|
||||
|
@ -110,7 +110,7 @@ func TestPayments_AccountDelete_Cancels_Subscription(t *testing.T) {
|
|||
Code: "pro",
|
||||
StripePriceID: "price_123",
|
||||
}))
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test"))
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
|
||||
|
||||
u, err := s.userManager.User("phil")
|
||||
require.Nil(t, err)
|
||||
|
@ -174,7 +174,7 @@ func TestPayments_Webhook_Subscription_Updated_Downgrade_From_PastDue_To_Active(
|
|||
AttachmentFileSizeLimit: 1000000,
|
||||
AttachmentTotalSizeLimit: 1000000,
|
||||
}))
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test"))
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
|
||||
require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
|
||||
require.Nil(t, s.userManager.AddReservation("phil", "atopic", user.PermissionDenyAll))
|
||||
require.Nil(t, s.userManager.AddReservation("phil", "ztopic", user.PermissionDenyAll))
|
||||
|
|
|
@ -625,7 +625,7 @@ func TestServer_Auth_Success_Admin(t *testing.T) {
|
|||
c := newTestConfigWithAuthFile(t)
|
||||
s := newTestServer(t, c)
|
||||
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, "unit-test"))
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin))
|
||||
|
||||
response := request(t, s, "GET", "/mytopic/auth", "", map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
|
@ -639,7 +639,7 @@ func TestServer_Auth_Success_User(t *testing.T) {
|
|||
c.AuthDefault = user.PermissionDenyAll
|
||||
s := newTestServer(t, c)
|
||||
|
||||
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser, "unit-test"))
|
||||
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser))
|
||||
require.Nil(t, s.userManager.AllowAccess("ben", "mytopic", user.PermissionReadWrite))
|
||||
|
||||
response := request(t, s, "GET", "/mytopic/auth", "", map[string]string{
|
||||
|
@ -653,7 +653,7 @@ func TestServer_Auth_Success_User_MultipleTopics(t *testing.T) {
|
|||
c.AuthDefault = user.PermissionDenyAll
|
||||
s := newTestServer(t, c)
|
||||
|
||||
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser, "unit-test"))
|
||||
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser))
|
||||
require.Nil(t, s.userManager.AllowAccess("ben", "mytopic", user.PermissionReadWrite))
|
||||
require.Nil(t, s.userManager.AllowAccess("ben", "anothertopic", user.PermissionReadWrite))
|
||||
|
||||
|
@ -674,7 +674,7 @@ func TestServer_Auth_Fail_InvalidPass(t *testing.T) {
|
|||
c.AuthDefault = user.PermissionDenyAll
|
||||
s := newTestServer(t, c)
|
||||
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, "unit-test"))
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin))
|
||||
|
||||
response := request(t, s, "GET", "/mytopic/auth", "", map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "INVALID"),
|
||||
|
@ -687,7 +687,7 @@ func TestServer_Auth_Fail_Unauthorized(t *testing.T) {
|
|||
c.AuthDefault = user.PermissionDenyAll
|
||||
s := newTestServer(t, c)
|
||||
|
||||
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser, "unit-test"))
|
||||
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser))
|
||||
require.Nil(t, s.userManager.AllowAccess("ben", "sometopic", user.PermissionReadWrite)) // Not mytopic!
|
||||
|
||||
response := request(t, s, "GET", "/mytopic/auth", "", map[string]string{
|
||||
|
@ -701,7 +701,7 @@ func TestServer_Auth_Fail_CannotPublish(t *testing.T) {
|
|||
c.AuthDefault = user.PermissionReadWrite // Open by default
|
||||
s := newTestServer(t, c)
|
||||
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, "unit-test"))
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin))
|
||||
require.Nil(t, s.userManager.AllowAccess(user.Everyone, "private", user.PermissionDenyAll))
|
||||
require.Nil(t, s.userManager.AllowAccess(user.Everyone, "announcements", user.PermissionRead))
|
||||
|
||||
|
@ -731,7 +731,7 @@ func TestServer_Auth_ViaQuery(t *testing.T) {
|
|||
c.AuthDefault = user.PermissionDenyAll
|
||||
s := newTestServer(t, c)
|
||||
|
||||
require.Nil(t, s.userManager.AddUser("ben", "some pass", user.RoleAdmin, "unit-test"))
|
||||
require.Nil(t, s.userManager.AddUser("ben", "some pass", user.RoleAdmin))
|
||||
|
||||
u := fmt.Sprintf("/mytopic/json?poll=1&auth=%s", base64.RawURLEncoding.EncodeToString([]byte(util.BasicAuth("ben", "some pass"))))
|
||||
response := request(t, s, "GET", u, "", nil)
|
||||
|
@ -749,7 +749,7 @@ func TestServer_StatsResetter(t *testing.T) {
|
|||
s := newTestServer(t, c)
|
||||
go s.runStatsResetter()
|
||||
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test"))
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
|
||||
require.Nil(t, s.userManager.AllowAccess("phil", "mytopic", user.PermissionReadWrite))
|
||||
|
||||
for i := 0; i < 5; i++ {
|
||||
|
@ -1137,7 +1137,7 @@ func TestServer_PublishWithTierBasedMessageLimitAndExpiry(t *testing.T) {
|
|||
MessagesLimit: 5,
|
||||
MessagesExpiryDuration: -5 * time.Second, // Second, what a hack!
|
||||
}))
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test"))
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
|
||||
require.Nil(t, s.userManager.ChangeTier("phil", "test"))
|
||||
|
||||
// Publish to reach message limit
|
||||
|
@ -1369,7 +1369,7 @@ func TestServer_PublishAttachmentWithTierBasedExpiry(t *testing.T) {
|
|||
AttachmentTotalSizeLimit: 200_000,
|
||||
AttachmentExpiryDuration: sevenDays, // 7 days
|
||||
}))
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test"))
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
|
||||
require.Nil(t, s.userManager.ChangeTier("phil", "test"))
|
||||
|
||||
// Publish and make sure we can retrieve it
|
||||
|
@ -1413,7 +1413,7 @@ func TestServer_PublishAttachmentWithTierBasedLimits(t *testing.T) {
|
|||
AttachmentTotalSizeLimit: 200_000,
|
||||
AttachmentExpiryDuration: 30 * time.Second,
|
||||
}))
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test"))
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
|
||||
require.Nil(t, s.userManager.ChangeTier("phil", "test"))
|
||||
|
||||
// Publish small file as anonymous
|
||||
|
|
|
@ -354,5 +354,6 @@ type apiStripeSubscriptionUpdatedEvent struct {
|
|||
}
|
||||
|
||||
type apiStripeSubscriptionDeletedEvent struct {
|
||||
ID string `json:"id"`
|
||||
Customer string `json:"customer"`
|
||||
}
|
||||
|
|
|
@ -49,7 +49,7 @@ func readQueryParam(r *http.Request, names ...string) string {
|
|||
}
|
||||
|
||||
func logMessagePrefix(v *visitor, m *message) string {
|
||||
return fmt.Sprintf("%s/%s/%s", v.ip, m.Topic, m.ID)
|
||||
return fmt.Sprintf("%s/%s/%s", v.String(), m.Topic, m.ID)
|
||||
}
|
||||
|
||||
func logHTTPPrefix(v *visitor, r *http.Request) string {
|
||||
|
@ -57,7 +57,14 @@ func logHTTPPrefix(v *visitor, r *http.Request) string {
|
|||
if requestURI == "" {
|
||||
requestURI = r.URL.Path
|
||||
}
|
||||
return fmt.Sprintf("%s HTTP %s %s", v.ip, r.Method, requestURI)
|
||||
return fmt.Sprintf("%s HTTP %s %s", v.String(), r.Method, requestURI)
|
||||
}
|
||||
|
||||
func logStripePrefix(customerID, subscriptionID string) string {
|
||||
if subscriptionID != "" {
|
||||
return fmt.Sprintf("%s/%s STRIPE", customerID, subscriptionID)
|
||||
}
|
||||
return fmt.Sprintf("%s STRIPE", customerID)
|
||||
}
|
||||
|
||||
func logSMTPPrefix(state *smtp.ConnectionState) string {
|
||||
|
|
|
@ -2,6 +2,7 @@ package server
|
|||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"heckel.io/ntfy/user"
|
||||
"net/netip"
|
||||
"sync"
|
||||
|
@ -119,6 +120,17 @@ func newVisitor(conf *Config, messageCache *messageCache, userManager *user.Mana
|
|||
}
|
||||
}
|
||||
|
||||
func (v *visitor) String() string {
|
||||
v.mu.Lock()
|
||||
defer v.mu.Unlock()
|
||||
if v.user != nil && v.user.Billing.StripeCustomerID != "" {
|
||||
return fmt.Sprintf("%s/%s/%s", v.ip.String(), v.user.ID, v.user.Billing.StripeCustomerID)
|
||||
} else if v.user != nil {
|
||||
return fmt.Sprintf("%s/%s", v.ip.String(), v.user.ID)
|
||||
}
|
||||
return v.ip.String()
|
||||
}
|
||||
|
||||
func (v *visitor) RequestAllowed() error {
|
||||
if !v.requestLimiter.Allow() {
|
||||
return errVisitorLimitReached
|
||||
|
@ -216,6 +228,12 @@ func (v *visitor) ResetStats() {
|
|||
}
|
||||
}
|
||||
|
||||
func (v *visitor) SetUser(u *user.User) {
|
||||
v.mu.Lock()
|
||||
defer v.mu.Unlock()
|
||||
v.user = u
|
||||
}
|
||||
|
||||
func (v *visitor) Limits() *visitorLimits {
|
||||
v.mu.Lock()
|
||||
defer v.mu.Unlock()
|
||||
|
|
|
@ -25,6 +25,7 @@ const (
|
|||
userPasswordBcryptCost = 10
|
||||
userAuthIntentionalSlowDownHash = "$2a$10$YFCQvqQDwIIwnJM1xkAYOeih0dg17UVGanaTStnrSzC8NCWxcLDwy" // Cost should match userPasswordBcryptCost
|
||||
userStatsQueueWriterInterval = 33 * time.Second
|
||||
userHardDeleteAfterDuration = 7 * 24 * time.Hour
|
||||
tokenPrefix = "tk_"
|
||||
tokenLength = 32
|
||||
tokenMaxCount = 10 // Only keep this many tokens in the table per user
|
||||
|
@ -57,7 +58,7 @@ const (
|
|||
CREATE UNIQUE INDEX idx_tier_price_id ON tier (stripe_price_id);
|
||||
CREATE TABLE IF NOT EXISTS user (
|
||||
id TEXT PRIMARY KEY,
|
||||
tier_id INT,
|
||||
tier_id TEXT,
|
||||
user TEXT NOT NULL,
|
||||
pass TEXT NOT NULL,
|
||||
role TEXT CHECK (role IN ('anonymous', 'admin', 'user')) NOT NULL,
|
||||
|
@ -70,8 +71,8 @@ const (
|
|||
stripe_subscription_status TEXT,
|
||||
stripe_subscription_paid_until INT,
|
||||
stripe_subscription_cancel_at INT,
|
||||
created_by TEXT NOT NULL,
|
||||
created_at INT NOT NULL,
|
||||
created INT NOT NULL,
|
||||
deleted INT,
|
||||
FOREIGN KEY (tier_id) REFERENCES tier (id)
|
||||
);
|
||||
CREATE UNIQUE INDEX idx_user ON user (user);
|
||||
|
@ -98,8 +99,8 @@ const (
|
|||
id INT PRIMARY KEY,
|
||||
version INT NOT NULL
|
||||
);
|
||||
INSERT INTO user (id, user, pass, role, sync_topic, created_by, created_at)
|
||||
VALUES ('u_everyone', '*', '', 'anonymous', '', 'system', UNIXEPOCH())
|
||||
INSERT INTO user (id, user, pass, role, sync_topic, created)
|
||||
VALUES ('` + everyoneID + `', '*', '', 'anonymous', '', UNIXEPOCH())
|
||||
ON CONFLICT (id) DO NOTHING;
|
||||
`
|
||||
createTablesQueries = `BEGIN; ` + createTablesQueriesNoTx + ` COMMIT;`
|
||||
|
@ -108,26 +109,26 @@ const (
|
|||
`
|
||||
|
||||
selectUserByIDQuery = `
|
||||
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.stripe_price_id
|
||||
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.stripe_price_id
|
||||
FROM user u
|
||||
LEFT JOIN tier t on t.id = u.tier_id
|
||||
WHERE u.id = ?
|
||||
`
|
||||
selectUserByNameQuery = `
|
||||
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.stripe_price_id
|
||||
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.stripe_price_id
|
||||
FROM user u
|
||||
LEFT JOIN tier t on t.id = u.tier_id
|
||||
WHERE user = ?
|
||||
WHERE user = ?
|
||||
`
|
||||
selectUserByTokenQuery = `
|
||||
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.stripe_price_id
|
||||
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.stripe_price_id
|
||||
FROM user u
|
||||
JOIN user_token t on u.id = t.user_id
|
||||
LEFT JOIN tier t on t.id = u.tier_id
|
||||
WHERE t.token = ? AND t.expires >= ?
|
||||
`
|
||||
selectUserByStripeCustomerIDQuery = `
|
||||
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.stripe_price_id
|
||||
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.stripe_price_id
|
||||
FROM user u
|
||||
LEFT JOIN tier t on t.id = u.tier_id
|
||||
WHERE u.stripe_customer_id = ?
|
||||
|
@ -141,8 +142,8 @@ const (
|
|||
`
|
||||
|
||||
insertUserQuery = `
|
||||
INSERT INTO user (id, user, pass, role, sync_topic, created_by, created_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?)
|
||||
INSERT INTO user (id, user, pass, role, sync_topic, created)
|
||||
VALUES (?, ?, ?, ?, ?, ?)
|
||||
`
|
||||
selectUsernamesQuery = `
|
||||
SELECT user
|
||||
|
@ -159,6 +160,8 @@ const (
|
|||
updateUserPrefsQuery = `UPDATE user SET prefs = ? WHERE user = ?`
|
||||
updateUserStatsQuery = `UPDATE user SET stats_messages = ?, stats_emails = ? WHERE id = ?`
|
||||
updateUserStatsResetAllQuery = `UPDATE user SET stats_messages = 0, stats_emails = 0`
|
||||
updateUserDeletedQuery = `UPDATE user SET deleted = ? WHERE id = ?`
|
||||
deleteUsersMarkedQuery = `DELETE FROM user WHERE deleted < ?`
|
||||
deleteUserQuery = `DELETE FROM user WHERE user = ?`
|
||||
|
||||
upsertUserAccessQuery = `
|
||||
|
@ -214,7 +217,8 @@ const (
|
|||
selectTokenCountQuery = `SELECT COUNT(*) FROM user_token WHERE user_id = ?`
|
||||
insertTokenQuery = `INSERT INTO user_token (user_id, token, expires) VALUES (?, ?, ?)`
|
||||
updateTokenExpiryQuery = `UPDATE user_token SET expires = ? WHERE user_id = (SELECT id FROM user WHERE user = ?) AND token = ?`
|
||||
deleteTokenQuery = `DELETE FROM user_token WHERE user_id = (SELECT id FROM user WHERE user = ?) AND token = ?`
|
||||
deleteTokenQuery = `DELETE FROM user_token WHERE user_id = ? AND token = ?`
|
||||
deleteAllTokenQuery = `DELETE FROM user_token WHERE user_id = ?`
|
||||
deleteExpiredTokensQuery = `DELETE FROM user_token WHERE expires < ?`
|
||||
deleteExcessTokensQuery = `
|
||||
DELETE FROM user_token
|
||||
|
@ -268,8 +272,8 @@ const (
|
|||
`
|
||||
migrate1To2SelectAllOldUsernamesNoTx = `SELECT user FROM user_old`
|
||||
migrate1To2InsertUserNoTx = `
|
||||
INSERT INTO user (id, user, pass, role, sync_topic, created_by, created_at)
|
||||
SELECT ?, user, pass, role, ?, 'admin', UNIXEPOCH() FROM user_old WHERE user = ?
|
||||
INSERT INTO user (id, user, pass, role, sync_topic, created)
|
||||
SELECT ?, user, pass, role, ?, UNIXEPOCH() FROM user_old WHERE user = ?
|
||||
`
|
||||
migrate1To2InsertFromOldTablesAndDropNoTx = `
|
||||
INSERT INTO user_access (user_id, topic, read, write)
|
||||
|
@ -320,9 +324,9 @@ func newManager(filename, startupQueries string, defaultAccess Permission, stats
|
|||
return manager, nil
|
||||
}
|
||||
|
||||
// Authenticate checks username and password and returns a User if correct. The method
|
||||
// returns in constant-ish time, regardless of whether the user exists or the password is
|
||||
// correct or incorrect.
|
||||
// Authenticate checks username and password and returns a User if correct, and the user has not been
|
||||
// marked as deleted. The method returns in constant-ish time, regardless of whether the user exists or
|
||||
// the password is correct or incorrect.
|
||||
func (a *Manager) Authenticate(username, password string) (*User, error) {
|
||||
if username == Everyone {
|
||||
return nil, ErrUnauthenticated
|
||||
|
@ -332,9 +336,12 @@ func (a *Manager) Authenticate(username, password string) (*User, error) {
|
|||
log.Trace("authentication of user %s failed (1): %s", username, err.Error())
|
||||
bcrypt.CompareHashAndPassword([]byte(userAuthIntentionalSlowDownHash), []byte("intentional slow-down to avoid timing attacks"))
|
||||
return nil, ErrUnauthenticated
|
||||
}
|
||||
if err := bcrypt.CompareHashAndPassword([]byte(user.Hash), []byte(password)); err != nil {
|
||||
log.Trace("authentication of user %s failed (2): %s", username, err.Error())
|
||||
} else if user.Deleted {
|
||||
log.Trace("authentication of user %s failed (2): user marked deleted", username)
|
||||
bcrypt.CompareHashAndPassword([]byte(userAuthIntentionalSlowDownHash), []byte("intentional slow-down to avoid timing attacks"))
|
||||
return nil, ErrUnauthenticated
|
||||
} else if err := bcrypt.CompareHashAndPassword([]byte(user.Hash), []byte(password)); err != nil {
|
||||
log.Trace("authentication of user %s failed (3): %s", username, err.Error())
|
||||
return nil, ErrUnauthenticated
|
||||
}
|
||||
return user, nil
|
||||
|
@ -415,7 +422,7 @@ func (a *Manager) RemoveToken(user *User) error {
|
|||
if user.Token == "" {
|
||||
return ErrUnauthorized
|
||||
}
|
||||
if _, err := a.db.Exec(deleteTokenQuery, user.Name, user.Token); err != nil {
|
||||
if _, err := a.db.Exec(deleteTokenQuery, user.ID, user.Token); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
|
@ -429,6 +436,14 @@ func (a *Manager) RemoveExpiredTokens() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// RemoveDeletedUsers deletes all users that have been marked deleted for
|
||||
func (a *Manager) RemoveDeletedUsers() error {
|
||||
if _, err := a.db.Exec(deleteUsersMarkedQuery, time.Now().Unix()); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ChangeSettings persists the user settings
|
||||
func (a *Manager) ChangeSettings(user *User) error {
|
||||
prefs, err := json.Marshal(user.Prefs)
|
||||
|
@ -533,7 +548,7 @@ func (a *Manager) resolvePerms(base, perm Permission) error {
|
|||
}
|
||||
|
||||
// AddUser adds a user with the given username, password and role
|
||||
func (a *Manager) AddUser(username, password string, role Role, createdBy string) error {
|
||||
func (a *Manager) AddUser(username, password string, role Role) error {
|
||||
if !AllowedUsername(username) || !AllowedRole(role) {
|
||||
return ErrInvalidArgument
|
||||
}
|
||||
|
@ -543,7 +558,7 @@ func (a *Manager) AddUser(username, password string, role Role, createdBy string
|
|||
}
|
||||
userID := util.RandomStringPrefix(userIDPrefix, userIDLength)
|
||||
syncTopic, now := util.RandomStringPrefix(syncTopicPrefix, syncTopicLength), time.Now().Unix()
|
||||
if _, err = a.db.Exec(insertUserQuery, userID, username, hash, role, syncTopic, createdBy, now); err != nil {
|
||||
if _, err = a.db.Exec(insertUserQuery, userID, username, hash, role, syncTopic, now); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
|
@ -562,6 +577,29 @@ func (a *Manager) RemoveUser(username string) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// MarkUserRemoved sets the deleted flag on the user, and deletes all access tokens. This prevents
|
||||
// successful auth via Authenticate. A background process will delete the user at a later date.
|
||||
func (a *Manager) MarkUserRemoved(user *User) error {
|
||||
if !AllowedUsername(user.Name) {
|
||||
return ErrInvalidArgument
|
||||
}
|
||||
tx, err := a.db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
if _, err := a.db.Exec(deleteUserAccessQuery, user.Name, user.Name); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec(deleteAllTokenQuery, user.ID); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec(updateUserDeletedQuery, time.Now().Add(userHardDeleteAfterDuration).Unix(), user.ID); err != nil {
|
||||
return err
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
// Users returns a list of users. It always also returns the Everyone user ("*").
|
||||
func (a *Manager) Users() ([]*User, error) {
|
||||
rows, err := a.db.Query(selectUsernamesQuery)
|
||||
|
@ -632,11 +670,11 @@ func (a *Manager) readUser(rows *sql.Rows) (*User, error) {
|
|||
var id, username, hash, role, prefs, syncTopic string
|
||||
var stripeCustomerID, stripeSubscriptionID, stripeSubscriptionStatus, stripePriceID, tierCode, tierName sql.NullString
|
||||
var messages, emails int64
|
||||
var messagesLimit, messagesExpiryDuration, emailsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration, stripeSubscriptionPaidUntil, stripeSubscriptionCancelAt sql.NullInt64
|
||||
var messagesLimit, messagesExpiryDuration, emailsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration, stripeSubscriptionPaidUntil, stripeSubscriptionCancelAt, deleted sql.NullInt64
|
||||
if !rows.Next() {
|
||||
return nil, ErrUserNotFound
|
||||
}
|
||||
if err := rows.Scan(&id, &username, &hash, &role, &prefs, &syncTopic, &messages, &emails, &stripeCustomerID, &stripeSubscriptionID, &stripeSubscriptionStatus, &stripeSubscriptionPaidUntil, &stripeSubscriptionCancelAt, &tierCode, &tierName, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &stripePriceID); err != nil {
|
||||
if err := rows.Scan(&id, &username, &hash, &role, &prefs, &syncTopic, &messages, &emails, &stripeCustomerID, &stripeSubscriptionID, &stripeSubscriptionStatus, &stripeSubscriptionPaidUntil, &stripeSubscriptionCancelAt, &deleted, &tierCode, &tierName, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &stripePriceID); err != nil {
|
||||
return nil, err
|
||||
} else if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
|
@ -659,6 +697,7 @@ func (a *Manager) readUser(rows *sql.Rows) (*User, error) {
|
|||
StripeSubscriptionPaidUntil: time.Unix(stripeSubscriptionPaidUntil.Int64, 0), // May be zero
|
||||
StripeSubscriptionCancelAt: time.Unix(stripeSubscriptionCancelAt.Int64, 0), // May be zero
|
||||
},
|
||||
Deleted: deleted.Valid,
|
||||
}
|
||||
if err := json.Unmarshal([]byte(prefs), user.Prefs); err != nil {
|
||||
return nil, err
|
||||
|
|
|
@ -13,8 +13,8 @@ const minBcryptTimingMillis = int64(50) // Ideally should be >100ms, but this sh
|
|||
|
||||
func TestManager_FullScenario_Default_DenyAll(t *testing.T) {
|
||||
a := newTestManager(t, PermissionDenyAll)
|
||||
require.Nil(t, a.AddUser("phil", "phil", RoleAdmin, "unit-test"))
|
||||
require.Nil(t, a.AddUser("ben", "ben", RoleUser, "unit-test"))
|
||||
require.Nil(t, a.AddUser("phil", "phil", RoleAdmin))
|
||||
require.Nil(t, a.AddUser("ben", "ben", RoleUser))
|
||||
require.Nil(t, a.AllowAccess("ben", "mytopic", PermissionReadWrite))
|
||||
require.Nil(t, a.AllowAccess("ben", "readme", PermissionRead))
|
||||
require.Nil(t, a.AllowAccess("ben", "writeme", PermissionWrite))
|
||||
|
@ -92,20 +92,20 @@ func TestManager_FullScenario_Default_DenyAll(t *testing.T) {
|
|||
|
||||
func TestManager_AddUser_Invalid(t *testing.T) {
|
||||
a := newTestManager(t, PermissionDenyAll)
|
||||
require.Equal(t, ErrInvalidArgument, a.AddUser(" invalid ", "pass", RoleAdmin, "unit-test"))
|
||||
require.Equal(t, ErrInvalidArgument, a.AddUser("validuser", "pass", "invalid-role", "unit-test"))
|
||||
require.Equal(t, ErrInvalidArgument, a.AddUser(" invalid ", "pass", RoleAdmin))
|
||||
require.Equal(t, ErrInvalidArgument, a.AddUser("validuser", "pass", "invalid-role"))
|
||||
}
|
||||
|
||||
func TestManager_AddUser_Timing(t *testing.T) {
|
||||
a := newTestManager(t, PermissionDenyAll)
|
||||
start := time.Now().UnixMilli()
|
||||
require.Nil(t, a.AddUser("user", "pass", RoleAdmin, "unit-test"))
|
||||
require.Nil(t, a.AddUser("user", "pass", RoleAdmin))
|
||||
require.GreaterOrEqual(t, time.Now().UnixMilli()-start, minBcryptTimingMillis)
|
||||
}
|
||||
|
||||
func TestManager_Authenticate_Timing(t *testing.T) {
|
||||
a := newTestManager(t, PermissionDenyAll)
|
||||
require.Nil(t, a.AddUser("user", "pass", RoleAdmin, "unit-test"))
|
||||
require.Nil(t, a.AddUser("user", "pass", RoleAdmin))
|
||||
|
||||
// Timing a correct attempt
|
||||
start := time.Now().UnixMilli()
|
||||
|
@ -126,10 +126,60 @@ func TestManager_Authenticate_Timing(t *testing.T) {
|
|||
require.GreaterOrEqual(t, time.Now().UnixMilli()-start, minBcryptTimingMillis)
|
||||
}
|
||||
|
||||
func TestManager_MarkUserRemoved_RemoveDeletedUsers(t *testing.T) {
|
||||
a := newTestManager(t, PermissionDenyAll)
|
||||
|
||||
// Create user, add reservations and token
|
||||
require.Nil(t, a.AddUser("user", "pass", RoleAdmin))
|
||||
require.Nil(t, a.AddReservation("user", "mytopic", PermissionRead))
|
||||
|
||||
u, err := a.User("user")
|
||||
require.Nil(t, err)
|
||||
require.False(t, u.Deleted)
|
||||
|
||||
token, err := a.CreateToken(u)
|
||||
require.Nil(t, err)
|
||||
|
||||
u, err = a.Authenticate("user", "pass")
|
||||
require.Nil(t, err)
|
||||
|
||||
_, err = a.AuthenticateToken(token.Value)
|
||||
require.Nil(t, err)
|
||||
|
||||
reservations, err := a.Reservations("user")
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 1, len(reservations))
|
||||
|
||||
// Mark deleted: cannot auth anymore, and all reservations are gone
|
||||
require.Nil(t, a.MarkUserRemoved(u))
|
||||
|
||||
_, err = a.Authenticate("user", "pass")
|
||||
require.Equal(t, ErrUnauthenticated, err)
|
||||
|
||||
_, err = a.AuthenticateToken(token.Value)
|
||||
require.Equal(t, ErrUnauthenticated, err)
|
||||
|
||||
reservations, err = a.Reservations("user")
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 0, len(reservations))
|
||||
|
||||
// Make sure user is still there
|
||||
u, err = a.User("user")
|
||||
require.Nil(t, err)
|
||||
require.True(t, u.Deleted)
|
||||
|
||||
_, err = a.db.Exec("UPDATE user SET deleted = ? WHERE id = ?", time.Now().Add(-1*(userHardDeleteAfterDuration+time.Hour)).Unix(), u.ID)
|
||||
require.Nil(t, err)
|
||||
require.Nil(t, a.RemoveDeletedUsers())
|
||||
|
||||
_, err = a.User("user")
|
||||
require.Equal(t, ErrUserNotFound, err)
|
||||
}
|
||||
|
||||
func TestManager_UserManagement(t *testing.T) {
|
||||
a := newTestManager(t, PermissionDenyAll)
|
||||
require.Nil(t, a.AddUser("phil", "phil", RoleAdmin, "unit-test"))
|
||||
require.Nil(t, a.AddUser("ben", "ben", RoleUser, "unit-test"))
|
||||
require.Nil(t, a.AddUser("phil", "phil", RoleAdmin))
|
||||
require.Nil(t, a.AddUser("ben", "ben", RoleUser))
|
||||
require.Nil(t, a.AllowAccess("ben", "mytopic", PermissionReadWrite))
|
||||
require.Nil(t, a.AllowAccess("ben", "readme", PermissionRead))
|
||||
require.Nil(t, a.AllowAccess("ben", "writeme", PermissionWrite))
|
||||
|
@ -219,7 +269,7 @@ func TestManager_UserManagement(t *testing.T) {
|
|||
|
||||
func TestManager_ChangePassword(t *testing.T) {
|
||||
a := newTestManager(t, PermissionDenyAll)
|
||||
require.Nil(t, a.AddUser("phil", "phil", RoleAdmin, "unit-test"))
|
||||
require.Nil(t, a.AddUser("phil", "phil", RoleAdmin))
|
||||
|
||||
_, err := a.Authenticate("phil", "phil")
|
||||
require.Nil(t, err)
|
||||
|
@ -233,7 +283,7 @@ func TestManager_ChangePassword(t *testing.T) {
|
|||
|
||||
func TestManager_ChangeRole(t *testing.T) {
|
||||
a := newTestManager(t, PermissionDenyAll)
|
||||
require.Nil(t, a.AddUser("ben", "ben", RoleUser, "unit-test"))
|
||||
require.Nil(t, a.AddUser("ben", "ben", RoleUser))
|
||||
require.Nil(t, a.AllowAccess("ben", "mytopic", PermissionReadWrite))
|
||||
require.Nil(t, a.AllowAccess("ben", "readme", PermissionRead))
|
||||
|
||||
|
@ -258,7 +308,7 @@ func TestManager_ChangeRole(t *testing.T) {
|
|||
|
||||
func TestManager_Reservations(t *testing.T) {
|
||||
a := newTestManager(t, PermissionDenyAll)
|
||||
require.Nil(t, a.AddUser("ben", "ben", RoleUser, "unit-test"))
|
||||
require.Nil(t, a.AddUser("ben", "ben", RoleUser))
|
||||
require.Nil(t, a.AddReservation("ben", "ztopic", PermissionDenyAll))
|
||||
require.Nil(t, a.AddReservation("ben", "readme", PermissionRead))
|
||||
require.Nil(t, a.AllowAccess("ben", "something-else", PermissionRead))
|
||||
|
@ -292,7 +342,7 @@ func TestManager_ChangeRoleFromTierUserToAdmin(t *testing.T) {
|
|||
AttachmentTotalSizeLimit: 524288000,
|
||||
AttachmentExpiryDuration: 24 * time.Hour,
|
||||
}))
|
||||
require.Nil(t, a.AddUser("ben", "ben", RoleUser, "unit-test"))
|
||||
require.Nil(t, a.AddUser("ben", "ben", RoleUser))
|
||||
require.Nil(t, a.ChangeTier("ben", "pro"))
|
||||
require.Nil(t, a.AddReservation("ben", "mytopic", PermissionDenyAll))
|
||||
|
||||
|
@ -340,7 +390,7 @@ func TestManager_ChangeRoleFromTierUserToAdmin(t *testing.T) {
|
|||
|
||||
func TestManager_Token_Valid(t *testing.T) {
|
||||
a := newTestManager(t, PermissionDenyAll)
|
||||
require.Nil(t, a.AddUser("ben", "ben", RoleUser, "unit-test"))
|
||||
require.Nil(t, a.AddUser("ben", "ben", RoleUser))
|
||||
|
||||
u, err := a.User("ben")
|
||||
require.Nil(t, err)
|
||||
|
@ -365,7 +415,7 @@ func TestManager_Token_Valid(t *testing.T) {
|
|||
|
||||
func TestManager_Token_Invalid(t *testing.T) {
|
||||
a := newTestManager(t, PermissionDenyAll)
|
||||
require.Nil(t, a.AddUser("ben", "ben", RoleUser, "unit-test"))
|
||||
require.Nil(t, a.AddUser("ben", "ben", RoleUser))
|
||||
|
||||
u, err := a.AuthenticateToken(strings.Repeat("x", 32)) // 32 == token length
|
||||
require.Nil(t, u)
|
||||
|
@ -378,7 +428,7 @@ func TestManager_Token_Invalid(t *testing.T) {
|
|||
|
||||
func TestManager_Token_Expire(t *testing.T) {
|
||||
a := newTestManager(t, PermissionDenyAll)
|
||||
require.Nil(t, a.AddUser("ben", "ben", RoleUser, "unit-test"))
|
||||
require.Nil(t, a.AddUser("ben", "ben", RoleUser))
|
||||
|
||||
u, err := a.User("ben")
|
||||
require.Nil(t, err)
|
||||
|
@ -426,7 +476,7 @@ func TestManager_Token_Expire(t *testing.T) {
|
|||
|
||||
func TestManager_Token_Extend(t *testing.T) {
|
||||
a := newTestManager(t, PermissionDenyAll)
|
||||
require.Nil(t, a.AddUser("ben", "ben", RoleUser, "unit-test"))
|
||||
require.Nil(t, a.AddUser("ben", "ben", RoleUser))
|
||||
|
||||
// Try to extend token for user without token
|
||||
u, err := a.User("ben")
|
||||
|
@ -453,7 +503,7 @@ func TestManager_Token_Extend(t *testing.T) {
|
|||
|
||||
func TestManager_Token_MaxCount_AutoDelete(t *testing.T) {
|
||||
a := newTestManager(t, PermissionDenyAll)
|
||||
require.Nil(t, a.AddUser("ben", "ben", RoleUser, "unit-test"))
|
||||
require.Nil(t, a.AddUser("ben", "ben", RoleUser))
|
||||
|
||||
// Try to extend token for user without token
|
||||
u, err := a.User("ben")
|
||||
|
@ -497,7 +547,7 @@ func TestManager_Token_MaxCount_AutoDelete(t *testing.T) {
|
|||
func TestManager_EnqueueStats(t *testing.T) {
|
||||
a, err := newManager(filepath.Join(t.TempDir(), "db"), "", PermissionReadWrite, 1500*time.Millisecond)
|
||||
require.Nil(t, err)
|
||||
require.Nil(t, a.AddUser("ben", "ben", RoleUser, "unit-test"))
|
||||
require.Nil(t, a.AddUser("ben", "ben", RoleUser))
|
||||
|
||||
// Baseline: No messages or emails
|
||||
u, err := a.User("ben")
|
||||
|
@ -527,7 +577,7 @@ func TestManager_EnqueueStats(t *testing.T) {
|
|||
func TestManager_ChangeSettings(t *testing.T) {
|
||||
a, err := newManager(filepath.Join(t.TempDir(), "db"), "", PermissionReadWrite, 1500*time.Millisecond)
|
||||
require.Nil(t, err)
|
||||
require.Nil(t, a.AddUser("ben", "ben", RoleUser, "unit-test"))
|
||||
require.Nil(t, a.AddUser("ben", "ben", RoleUser))
|
||||
|
||||
// No settings
|
||||
u, err := a.User("ben")
|
||||
|
|
|
@ -20,8 +20,7 @@ type User struct {
|
|||
Stats *Stats
|
||||
Billing *Billing
|
||||
SyncTopic string
|
||||
Created time.Time
|
||||
LastSeen time.Time
|
||||
Deleted bool
|
||||
}
|
||||
|
||||
// Auther is an interface for authentication and authorization
|
||||
|
@ -186,7 +185,8 @@ const (
|
|||
|
||||
// Everyone is a special username representing anonymous users
|
||||
const (
|
||||
Everyone = "*"
|
||||
Everyone = "*"
|
||||
everyoneID = "u_everyone"
|
||||
)
|
||||
|
||||
var (
|
||||
|
|
12
util/util.go
12
util/util.go
|
@ -324,3 +324,15 @@ func UnmarshalJSONWithLimit[T any](r io.ReadCloser, limit int) (*T, error) {
|
|||
}
|
||||
return &obj, nil
|
||||
}
|
||||
|
||||
// Retry executes function f until if succeeds, and then returns t. If f fails, it sleeps
|
||||
// and tries again. The sleep durations are passed as the after params.
|
||||
func Retry[T any](f func() (*T, error), after ...time.Duration) (t *T, err error) {
|
||||
for _, delay := range after {
|
||||
if t, err = f(); err == nil {
|
||||
return t, nil
|
||||
}
|
||||
time.Sleep(delay)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue