1
0
Fork 0
mirror of https://github.com/binwiederhier/ntfy.git synced 2024-11-21 19:03:26 +01:00

Add bandwidth limit to tier; fix display name sync issues

This commit is contained in:
binwiederhier 2023-01-25 10:05:54 -05:00
parent 1771cb3fdb
commit 236254d907
13 changed files with 119 additions and 51 deletions

View file

@ -285,7 +285,7 @@ func execServe(c *cli.Context) error {
conf.TotalTopicLimit = totalTopicLimit
conf.VisitorSubscriptionLimit = visitorSubscriptionLimit
conf.VisitorAttachmentTotalSizeLimit = visitorAttachmentTotalSizeLimit
conf.VisitorAttachmentDailyBandwidthLimit = int(visitorAttachmentDailyBandwidthLimit)
conf.VisitorAttachmentDailyBandwidthLimit = visitorAttachmentDailyBandwidthLimit
conf.VisitorRequestLimitBurst = visitorRequestLimitBurst
conf.VisitorRequestLimitReplenish = visitorRequestLimitReplenish
conf.VisitorRequestExemptIPAddrs = visitorRequestLimitExemptIPs

View file

@ -101,7 +101,7 @@ type Config struct {
TotalAttachmentSizeLimit int64
VisitorSubscriptionLimit int
VisitorAttachmentTotalSizeLimit int64
VisitorAttachmentDailyBandwidthLimit int
VisitorAttachmentDailyBandwidthLimit int64
VisitorRequestLimitBurst int
VisitorRequestLimitReplenish time.Duration
VisitorRequestExemptIPAddrs []netip.Prefix

View file

@ -40,7 +40,6 @@ TODO
- HIGH Rate limiting: dailyLimitToRate is wrong? + TESTS
- HIGH Rate limiting: Sensitive endpoints (account/login/change-password/...)
- HIGH Rate limiting: Bandwidth limit must be in tier + TESTS
- MEDIUM: Races with v.user (see publishSyncEventAsync test)
- MEDIUM: Reservation (UI): Show "This topic is reserved" error message when trying to reserve a reserved topic (Thorben)
- MEDIUM: Reservation (UI): Ask for confirmation when removing reservation (deadcade)
@ -866,7 +865,6 @@ func (s *Server) handleBodyAsAttachment(r *http.Request, v *visitor, m *message,
util.NewFixedLimiter(vinfo.Limits.AttachmentFileSizeLimit),
util.NewFixedLimiter(vinfo.Stats.AttachmentTotalSizeRemaining),
}
fmt.Printf("limiters = %#v\nv = %#v\n", limiters, v)
m.Attachment.Size, err = s.fileCache.Write(m.ID, body, limiters...)
if err == util.ErrLimitReached {
return errHTTPEntityTooLargeAttachment

View file

@ -11,6 +11,7 @@ import (
const (
subscriptionIDLength = 16
subscriptionIDPrefix = "su_"
syncTopicAccountSyncEvent = "sync"
)
@ -55,6 +56,7 @@ func (s *Server) handleAccountGet(w http.ResponseWriter, _ *http.Request, v *vis
AttachmentTotalSize: limits.AttachmentTotalSizeLimit,
AttachmentFileSize: limits.AttachmentFileSizeLimit,
AttachmentExpiryDuration: int64(limits.AttachmentExpiryDuration.Seconds()),
AttachmentBandwidth: limits.AttachmentBandwidthLimit,
},
Stats: &apiAccountStats{
Messages: stats.Messages,
@ -249,7 +251,7 @@ func (s *Server) handleAccountSubscriptionAdd(w http.ResponseWriter, r *http.Req
}
}
if newSubscription.ID == "" {
newSubscription.ID = util.RandomString(subscriptionIDLength)
newSubscription.ID = util.RandomStringPrefix(subscriptionIDPrefix, subscriptionIDLength)
v.user.Prefs.Subscriptions = append(v.user.Prefs.Subscriptions, newSubscription)
if err := s.userManager.ChangeSettings(v.user); err != nil {
return err

View file

@ -153,9 +153,9 @@ func TestAccount_ChangeSettings(t *testing.T) {
require.Equal(t, 200, rr.Code)
account, _ := util.UnmarshalJSON[apiAccountResponse](io.NopCloser(rr.Body))
require.Equal(t, "de", account.Language)
require.Equal(t, 86400, account.Notification.DeleteAfter)
require.Equal(t, "juntos", account.Notification.Sound)
require.Equal(t, 0, account.Notification.MinPriority) // Not set
require.Equal(t, util.Int(86400), account.Notification.DeleteAfter)
require.Equal(t, util.String("juntos"), account.Notification.Sound)
require.Nil(t, account.Notification.MinPriority) // Not set
}
func TestAccount_Subscription_AddUpdateDelete(t *testing.T) {
@ -176,7 +176,7 @@ func TestAccount_Subscription_AddUpdateDelete(t *testing.T) {
require.NotEmpty(t, account.Subscriptions[0].ID)
require.Equal(t, "http://abc.com", account.Subscriptions[0].BaseURL)
require.Equal(t, "def", account.Subscriptions[0].Topic)
require.Equal(t, "", account.Subscriptions[0].DisplayName)
require.Nil(t, account.Subscriptions[0].DisplayName)
subscriptionID := account.Subscriptions[0].ID
rr = request(t, s, "PATCH", "/v1/account/subscription/"+subscriptionID, `{"display_name": "ding dong"}`, map[string]string{
@ -193,7 +193,7 @@ func TestAccount_Subscription_AddUpdateDelete(t *testing.T) {
require.Equal(t, subscriptionID, account.Subscriptions[0].ID)
require.Equal(t, "http://abc.com", account.Subscriptions[0].BaseURL)
require.Equal(t, "def", account.Subscriptions[0].Topic)
require.Equal(t, "ding dong", account.Subscriptions[0].DisplayName)
require.Equal(t, util.String("ding dong"), account.Subscriptions[0].DisplayName)
rr = request(t, s, "DELETE", "/v1/account/subscription/"+subscriptionID, "", map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
@ -402,6 +402,7 @@ func TestAccount_Reservation_AddRemoveUserWithTierSuccess(t *testing.T) {
AttachmentFileSizeLimit: 1231231,
AttachmentTotalSizeLimit: 123123,
AttachmentExpiryDuration: 10800 * time.Second,
AttachmentBandwidthLimit: 21474836480,
}))
require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
@ -442,6 +443,7 @@ func TestAccount_Reservation_AddRemoveUserWithTierSuccess(t *testing.T) {
require.Equal(t, int64(1231231), account.Limits.AttachmentFileSize)
require.Equal(t, int64(123123), account.Limits.AttachmentTotalSize)
require.Equal(t, int64(10800), account.Limits.AttachmentExpiryDuration)
require.Equal(t, int64(21474836480), account.Limits.AttachmentBandwidth)
require.Equal(t, 2, len(account.Reservations))
require.Equal(t, "another", account.Reservations[0].Topic)
require.Equal(t, "write-only", account.Reservations[0].Everyone)

View file

@ -265,6 +265,7 @@ func TestPayments_Webhook_Subscription_Updated_Downgrade_From_PastDue_To_Active(
AttachmentExpiryDuration: time.Hour,
AttachmentFileSizeLimit: 1000000,
AttachmentTotalSizeLimit: 1000000,
AttachmentBandwidthLimit: 1000000,
}))
require.Nil(t, s.userManager.CreateTier(&user.Tier{
Code: "pro",
@ -275,6 +276,7 @@ func TestPayments_Webhook_Subscription_Updated_Downgrade_From_PastDue_To_Active(
AttachmentExpiryDuration: time.Hour,
AttachmentFileSizeLimit: 1000000,
AttachmentTotalSizeLimit: 1000000,
AttachmentBandwidthLimit: 1000000,
}))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
require.Nil(t, s.userManager.ChangeTier("phil", "pro"))

View file

@ -1368,6 +1368,7 @@ func TestServer_PublishAttachmentWithTierBasedExpiry(t *testing.T) {
AttachmentFileSizeLimit: 50_000,
AttachmentTotalSizeLimit: 200_000,
AttachmentExpiryDuration: sevenDays, // 7 days
AttachmentBandwidthLimit: 100000,
}))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
require.Nil(t, s.userManager.ChangeTier("phil", "test"))
@ -1376,6 +1377,7 @@ func TestServer_PublishAttachmentWithTierBasedExpiry(t *testing.T) {
response := request(t, s, "PUT", "/mytopic", content, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, response.Code)
msg := toMessage(t, response.Body.String())
require.Contains(t, msg.Attachment.URL, "http://127.0.0.1:12345/file/")
require.True(t, msg.Attachment.Expires > time.Now().Add(sevenDays-30*time.Second).Unix())
@ -1396,6 +1398,46 @@ func TestServer_PublishAttachmentWithTierBasedExpiry(t *testing.T) {
require.Equal(t, 200, response.Code)
}
func TestServer_PublishAttachmentWithTierBasedBandwidthLimit(t *testing.T) {
content := util.RandomString(5000) // > 4096
c := newTestConfigWithAuthFile(t)
s := newTestServer(t, c)
// Create tier with certain limits
require.Nil(t, s.userManager.CreateTier(&user.Tier{
Code: "test",
MessagesLimit: 10,
MessagesExpiryDuration: time.Hour,
AttachmentFileSizeLimit: 50_000,
AttachmentTotalSizeLimit: 200_000,
AttachmentExpiryDuration: time.Hour,
AttachmentBandwidthLimit: 14000, // < 3x5000 bytes -> enough for one upload, one download
}))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
require.Nil(t, s.userManager.ChangeTier("phil", "test"))
// Publish and make sure we can retrieve it
rr := request(t, s, "PUT", "/mytopic", content, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code)
msg := toMessage(t, rr.Body.String())
// Retrieve it (first time succeeds)
rr = request(t, s, "GET", "/file/"+msg.ID, content, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code)
require.Equal(t, content, rr.Body.String())
// Retrieve it AGAIN (fails, due to bandwidth limit)
rr = request(t, s, "GET", "/file/"+msg.ID, content, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 429, rr.Code)
}
func TestServer_PublishAttachmentWithTierBasedLimits(t *testing.T) {
smallFile := util.RandomString(20_000)
largeFile := util.RandomString(50_000)
@ -1412,6 +1454,7 @@ func TestServer_PublishAttachmentWithTierBasedLimits(t *testing.T) {
AttachmentFileSizeLimit: 50_000,
AttachmentTotalSizeLimit: 200_000,
AttachmentExpiryDuration: 30 * time.Second,
AttachmentBandwidthLimit: 1000000,
}))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
require.Nil(t, s.userManager.ChangeTier("phil", "test"))

View file

@ -246,7 +246,7 @@ type apiAccountTier struct {
}
type apiAccountLimits struct {
Basis string `json:"basis,omitempty"` // "ip", "role" or "tier"
Basis string `json:"basis,omitempty"` // "ip" or "tier"
Messages int64 `json:"messages"`
MessagesExpiryDuration int64 `json:"messages_expiry_duration"`
Emails int64 `json:"emails"`
@ -254,6 +254,7 @@ type apiAccountLimits struct {
AttachmentTotalSize int64 `json:"attachment_total_size"`
AttachmentFileSize int64 `json:"attachment_file_size"`
AttachmentExpiryDuration int64 `json:"attachment_expiry_duration"`
AttachmentBandwidth int64 `json:"attachment_bandwidth"`
}
type apiAccountStats struct {

View file

@ -31,9 +31,9 @@ var (
type visitor struct {
config *Config
messageCache *messageCache
userManager *user.Manager // May be nil!
ip netip.Addr
user *user.User
userManager *user.Manager // May be nil
ip netip.Addr // Visitor IP address
user *user.User // Only set if authenticated user, otherwise nil
messages int64 // Number of messages sent, reset every day
emails int64 // Number of emails sent, reset every day
requestLimiter *rate.Limiter // Rate limiter for (almost) all requests (including messages)
@ -61,6 +61,7 @@ type visitorLimits struct {
AttachmentTotalSizeLimit int64
AttachmentFileSizeLimit int64
AttachmentExpiryDuration time.Duration
AttachmentBandwidthLimit int64
}
type visitorStats struct {
@ -84,7 +85,7 @@ const (
)
func newVisitor(conf *Config, messageCache *messageCache, userManager *user.Manager, ip netip.Addr, user *user.User) *visitor {
var messagesLimiter util.Limiter
var messagesLimiter, attachmentBandwidthLimiter util.Limiter
var requestLimiter, emailsLimiter, accountLimiter *rate.Limiter
var messages, emails int64
if user != nil {
@ -97,9 +98,11 @@ func newVisitor(conf *Config, messageCache *messageCache, userManager *user.Mana
requestLimiter = rate.NewLimiter(dailyLimitToRate(user.Tier.MessagesLimit), conf.VisitorRequestLimitBurst)
messagesLimiter = util.NewFixedLimiter(user.Tier.MessagesLimit)
emailsLimiter = rate.NewLimiter(dailyLimitToRate(user.Tier.EmailsLimit), conf.VisitorEmailLimitBurst)
attachmentBandwidthLimiter = util.NewBytesLimiter(int(user.Tier.AttachmentBandwidthLimit), 24*time.Hour)
} else {
requestLimiter = rate.NewLimiter(rate.Every(conf.VisitorRequestLimitReplenish), conf.VisitorRequestLimitBurst)
emailsLimiter = rate.NewLimiter(rate.Every(conf.VisitorEmailLimitReplenish), conf.VisitorEmailLimitBurst)
attachmentBandwidthLimiter = util.NewBytesLimiter(int(conf.VisitorAttachmentDailyBandwidthLimit), 24*time.Hour)
}
return &visitor{
config: conf,
@ -113,7 +116,7 @@ func newVisitor(conf *Config, messageCache *messageCache, userManager *user.Mana
messagesLimiter: messagesLimiter, // May be nil
emailsLimiter: emailsLimiter,
subscriptionLimiter: util.NewFixedLimiter(int64(conf.VisitorSubscriptionLimit)),
bandwidthLimiter: util.NewBytesLimiter(conf.VisitorAttachmentDailyBandwidthLimit, 24*time.Hour),
bandwidthLimiter: attachmentBandwidthLimiter,
accountLimiter: accountLimiter, // May be nil
firebase: time.Unix(0, 0),
seen: time.Now(),
@ -259,6 +262,7 @@ func (v *visitor) Limits() *visitorLimits {
limits.AttachmentTotalSizeLimit = v.user.Tier.AttachmentTotalSizeLimit
limits.AttachmentFileSizeLimit = v.user.Tier.AttachmentFileSizeLimit
limits.AttachmentExpiryDuration = v.user.Tier.AttachmentExpiryDuration
limits.AttachmentBandwidthLimit = v.user.Tier.AttachmentBandwidthLimit
}
return limits
}
@ -327,5 +331,6 @@ func defaultVisitorLimits(conf *Config) *visitorLimits {
AttachmentTotalSizeLimit: conf.VisitorAttachmentTotalSizeLimit,
AttachmentFileSizeLimit: conf.AttachmentFileSizeLimit,
AttachmentExpiryDuration: conf.AttachmentExpiryDuration,
AttachmentBandwidthLimit: conf.VisitorAttachmentDailyBandwidthLimit,
}
}

View file

@ -52,6 +52,7 @@ const (
attachment_file_size_limit INT NOT NULL,
attachment_total_size_limit INT NOT NULL,
attachment_expiry_duration INT NOT NULL,
attachment_bandwidth_limit INT NOT NULL,
stripe_price_id TEXT
);
CREATE UNIQUE INDEX idx_tier_code ON tier (code);
@ -109,26 +110,26 @@ const (
`
selectUserByIDQuery = `
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.stripe_price_id
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_price_id
FROM user u
LEFT JOIN tier t on t.id = u.tier_id
WHERE u.id = ?
`
selectUserByNameQuery = `
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.stripe_price_id
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_price_id
FROM user u
LEFT JOIN tier t on t.id = u.tier_id
WHERE user = ?
`
selectUserByTokenQuery = `
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.stripe_price_id
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_price_id
FROM user u
JOIN user_token t on u.id = t.user_id
LEFT JOIN tier t on t.id = u.tier_id
WHERE t.token = ? AND t.expires >= ?
`
selectUserByStripeCustomerIDQuery = `
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.stripe_price_id
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_price_id
FROM user u
LEFT JOIN tier t on t.id = u.tier_id
WHERE u.stripe_customer_id = ?
@ -232,20 +233,20 @@ const (
`
insertTierQuery = `
INSERT INTO tier (id, code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, stripe_price_id)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
INSERT INTO tier (id, code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_price_id)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
`
selectTiersQuery = `
SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, stripe_price_id
SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_price_id
FROM tier
`
selectTierByCodeQuery = `
SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, stripe_price_id
SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_price_id
FROM tier
WHERE code = ?
`
selectTierByPriceIDQuery = `
SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, stripe_price_id
SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_price_id
FROM tier
WHERE stripe_price_id = ?
`
@ -670,11 +671,11 @@ func (a *Manager) readUser(rows *sql.Rows) (*User, error) {
var id, username, hash, role, prefs, syncTopic string
var stripeCustomerID, stripeSubscriptionID, stripeSubscriptionStatus, stripePriceID, tierCode, tierName sql.NullString
var messages, emails int64
var messagesLimit, messagesExpiryDuration, emailsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration, stripeSubscriptionPaidUntil, stripeSubscriptionCancelAt, deleted sql.NullInt64
var messagesLimit, messagesExpiryDuration, emailsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration, attachmentBandwidthLimit, stripeSubscriptionPaidUntil, stripeSubscriptionCancelAt, deleted sql.NullInt64
if !rows.Next() {
return nil, ErrUserNotFound
}
if err := rows.Scan(&id, &username, &hash, &role, &prefs, &syncTopic, &messages, &emails, &stripeCustomerID, &stripeSubscriptionID, &stripeSubscriptionStatus, &stripeSubscriptionPaidUntil, &stripeSubscriptionCancelAt, &deleted, &tierCode, &tierName, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &stripePriceID); err != nil {
if err := rows.Scan(&id, &username, &hash, &role, &prefs, &syncTopic, &messages, &emails, &stripeCustomerID, &stripeSubscriptionID, &stripeSubscriptionStatus, &stripeSubscriptionPaidUntil, &stripeSubscriptionCancelAt, &deleted, &tierCode, &tierName, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &attachmentBandwidthLimit, &stripePriceID); err != nil {
return nil, err
} else if err := rows.Err(); err != nil {
return nil, err
@ -714,6 +715,7 @@ func (a *Manager) readUser(rows *sql.Rows) (*User, error) {
AttachmentFileSizeLimit: attachmentFileSizeLimit.Int64,
AttachmentTotalSizeLimit: attachmentTotalSizeLimit.Int64,
AttachmentExpiryDuration: time.Duration(attachmentExpiryDuration.Int64) * time.Second,
AttachmentBandwidthLimit: attachmentBandwidthLimit.Int64,
StripePriceID: stripePriceID.String, // May be empty
}
}
@ -994,7 +996,7 @@ func (a *Manager) DefaultAccess() Permission {
// CreateTier creates a new tier in the database
func (a *Manager) CreateTier(tier *Tier) error {
tierID := util.RandomStringPrefix(tierIDPrefix, tierIDLength)
if _, err := a.db.Exec(insertTierQuery, tierID, tier.Code, tier.Name, tier.MessagesLimit, int64(tier.MessagesExpiryDuration.Seconds()), tier.EmailsLimit, tier.ReservationsLimit, tier.AttachmentFileSizeLimit, tier.AttachmentTotalSizeLimit, int64(tier.AttachmentExpiryDuration.Seconds()), tier.StripePriceID); err != nil {
if _, err := a.db.Exec(insertTierQuery, tierID, tier.Code, tier.Name, tier.MessagesLimit, int64(tier.MessagesExpiryDuration.Seconds()), tier.EmailsLimit, tier.ReservationsLimit, tier.AttachmentFileSizeLimit, tier.AttachmentTotalSizeLimit, int64(tier.AttachmentExpiryDuration.Seconds()), tier.AttachmentBandwidthLimit, tier.StripePriceID); err != nil {
return err
}
return nil
@ -1051,11 +1053,11 @@ func (a *Manager) TierByStripePrice(priceID string) (*Tier, error) {
func (a *Manager) readTier(rows *sql.Rows) (*Tier, error) {
var id, code, name string
var stripePriceID sql.NullString
var messagesLimit, messagesExpiryDuration, emailsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration sql.NullInt64
var messagesLimit, messagesExpiryDuration, emailsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration, attachmentBandwidthLimit sql.NullInt64
if !rows.Next() {
return nil, ErrTierNotFound
}
if err := rows.Scan(&id, &code, &name, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &stripePriceID); err != nil {
if err := rows.Scan(&id, &code, &name, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &attachmentBandwidthLimit, &stripePriceID); err != nil {
return nil, err
} else if err := rows.Err(); err != nil {
return nil, err
@ -1072,6 +1074,7 @@ func (a *Manager) readTier(rows *sql.Rows) (*Tier, error) {
AttachmentFileSizeLimit: attachmentFileSizeLimit.Int64,
AttachmentTotalSizeLimit: attachmentTotalSizeLimit.Int64,
AttachmentExpiryDuration: time.Duration(attachmentExpiryDuration.Int64) * time.Second,
AttachmentBandwidthLimit: attachmentBandwidthLimit.Int64,
StripePriceID: stripePriceID.String, // May be empty
}, nil
}

View file

@ -3,6 +3,7 @@ package user
import (
"database/sql"
"github.com/stretchr/testify/require"
"heckel.io/ntfy/util"
"path/filepath"
"strings"
"testing"
@ -583,21 +584,21 @@ func TestManager_ChangeSettings(t *testing.T) {
require.Nil(t, err)
require.Nil(t, u.Prefs.Subscriptions)
require.Nil(t, u.Prefs.Notification)
require.Equal(t, "", u.Prefs.Language)
require.Nil(t, u.Prefs.Language)
// Save with new settings
u.Prefs = &Prefs{
Language: "de",
Language: util.String("de"),
Notification: &NotificationPrefs{
Sound: "ding",
MinPriority: 2,
Sound: util.String("ding"),
MinPriority: util.Int(2),
},
Subscriptions: []*Subscription{
{
ID: "someID",
BaseURL: "https://ntfy.sh",
Topic: "mytopic",
DisplayName: "My Topic",
DisplayName: util.String("My Topic"),
},
},
}
@ -606,14 +607,14 @@ func TestManager_ChangeSettings(t *testing.T) {
// Read again
u, err = a.User("ben")
require.Nil(t, err)
require.Equal(t, "de", u.Prefs.Language)
require.Equal(t, "ding", u.Prefs.Notification.Sound)
require.Equal(t, 2, u.Prefs.Notification.MinPriority)
require.Equal(t, 0, u.Prefs.Notification.DeleteAfter)
require.Equal(t, util.String("de"), u.Prefs.Language)
require.Equal(t, util.String("ding"), u.Prefs.Notification.Sound)
require.Equal(t, util.Int(2), u.Prefs.Notification.MinPriority)
require.Nil(t, u.Prefs.Notification.DeleteAfter)
require.Equal(t, "someID", u.Prefs.Subscriptions[0].ID)
require.Equal(t, "https://ntfy.sh", u.Prefs.Subscriptions[0].BaseURL)
require.Equal(t, "mytopic", u.Prefs.Subscriptions[0].Topic)
require.Equal(t, "My Topic", u.Prefs.Subscriptions[0].DisplayName)
require.Equal(t, util.String("My Topic"), u.Prefs.Subscriptions[0].DisplayName)
}
func TestSqliteCache_Migration_From1(t *testing.T) {

View file

@ -50,17 +50,18 @@ type Prefs struct {
// Tier represents a user's account type, including its account limits
type Tier struct {
ID string
Code string
Name string
MessagesLimit int64
MessagesExpiryDuration time.Duration
EmailsLimit int64
ReservationsLimit int64
AttachmentFileSizeLimit int64
AttachmentTotalSizeLimit int64
AttachmentExpiryDuration time.Duration
StripePriceID string
ID string // Tier identifier (ti_...)
Code string // Code of the tier
Name string // Name of the tier
MessagesLimit int64 // Daily message limit
MessagesExpiryDuration time.Duration // Cache duration for messages
EmailsLimit int64 // Daily email limit
ReservationsLimit int64 // Number of topic reservations allowed by user
AttachmentFileSizeLimit int64 // Max file size per file (bytes)
AttachmentTotalSizeLimit int64 // Total file size for all files of this user (bytes)
AttachmentExpiryDuration time.Duration // Duration after which attachments will be deleted
AttachmentBandwidthLimit int64 // Daily bandwidth limit for the user
StripePriceID string // Price ID for paid tiers (price_...)
}
// Subscription represents a user's topic subscription

View file

@ -336,3 +336,13 @@ func Retry[T any](f func() (*T, error), after ...time.Duration) (t *T, err error
}
return nil, err
}
// String turns a string into a pointer of a string
func String(v string) *string {
return &v
}
// Int turns a string into a pointer of an int
func Int(v int) *int {
return &v
}