mirror of
https://github.com/binwiederhier/ntfy.git
synced 2024-11-21 19:03:26 +01:00
Add bandwidth limit to tier; fix display name sync issues
This commit is contained in:
parent
1771cb3fdb
commit
236254d907
13 changed files with 119 additions and 51 deletions
|
@ -285,7 +285,7 @@ func execServe(c *cli.Context) error {
|
|||
conf.TotalTopicLimit = totalTopicLimit
|
||||
conf.VisitorSubscriptionLimit = visitorSubscriptionLimit
|
||||
conf.VisitorAttachmentTotalSizeLimit = visitorAttachmentTotalSizeLimit
|
||||
conf.VisitorAttachmentDailyBandwidthLimit = int(visitorAttachmentDailyBandwidthLimit)
|
||||
conf.VisitorAttachmentDailyBandwidthLimit = visitorAttachmentDailyBandwidthLimit
|
||||
conf.VisitorRequestLimitBurst = visitorRequestLimitBurst
|
||||
conf.VisitorRequestLimitReplenish = visitorRequestLimitReplenish
|
||||
conf.VisitorRequestExemptIPAddrs = visitorRequestLimitExemptIPs
|
||||
|
|
|
@ -101,7 +101,7 @@ type Config struct {
|
|||
TotalAttachmentSizeLimit int64
|
||||
VisitorSubscriptionLimit int
|
||||
VisitorAttachmentTotalSizeLimit int64
|
||||
VisitorAttachmentDailyBandwidthLimit int
|
||||
VisitorAttachmentDailyBandwidthLimit int64
|
||||
VisitorRequestLimitBurst int
|
||||
VisitorRequestLimitReplenish time.Duration
|
||||
VisitorRequestExemptIPAddrs []netip.Prefix
|
||||
|
|
|
@ -40,7 +40,6 @@ TODO
|
|||
|
||||
- HIGH Rate limiting: dailyLimitToRate is wrong? + TESTS
|
||||
- HIGH Rate limiting: Sensitive endpoints (account/login/change-password/...)
|
||||
- HIGH Rate limiting: Bandwidth limit must be in tier + TESTS
|
||||
- MEDIUM: Races with v.user (see publishSyncEventAsync test)
|
||||
- MEDIUM: Reservation (UI): Show "This topic is reserved" error message when trying to reserve a reserved topic (Thorben)
|
||||
- MEDIUM: Reservation (UI): Ask for confirmation when removing reservation (deadcade)
|
||||
|
@ -866,7 +865,6 @@ func (s *Server) handleBodyAsAttachment(r *http.Request, v *visitor, m *message,
|
|||
util.NewFixedLimiter(vinfo.Limits.AttachmentFileSizeLimit),
|
||||
util.NewFixedLimiter(vinfo.Stats.AttachmentTotalSizeRemaining),
|
||||
}
|
||||
fmt.Printf("limiters = %#v\nv = %#v\n", limiters, v)
|
||||
m.Attachment.Size, err = s.fileCache.Write(m.ID, body, limiters...)
|
||||
if err == util.ErrLimitReached {
|
||||
return errHTTPEntityTooLargeAttachment
|
||||
|
|
|
@ -11,6 +11,7 @@ import (
|
|||
|
||||
const (
|
||||
subscriptionIDLength = 16
|
||||
subscriptionIDPrefix = "su_"
|
||||
syncTopicAccountSyncEvent = "sync"
|
||||
)
|
||||
|
||||
|
@ -55,6 +56,7 @@ func (s *Server) handleAccountGet(w http.ResponseWriter, _ *http.Request, v *vis
|
|||
AttachmentTotalSize: limits.AttachmentTotalSizeLimit,
|
||||
AttachmentFileSize: limits.AttachmentFileSizeLimit,
|
||||
AttachmentExpiryDuration: int64(limits.AttachmentExpiryDuration.Seconds()),
|
||||
AttachmentBandwidth: limits.AttachmentBandwidthLimit,
|
||||
},
|
||||
Stats: &apiAccountStats{
|
||||
Messages: stats.Messages,
|
||||
|
@ -249,7 +251,7 @@ func (s *Server) handleAccountSubscriptionAdd(w http.ResponseWriter, r *http.Req
|
|||
}
|
||||
}
|
||||
if newSubscription.ID == "" {
|
||||
newSubscription.ID = util.RandomString(subscriptionIDLength)
|
||||
newSubscription.ID = util.RandomStringPrefix(subscriptionIDPrefix, subscriptionIDLength)
|
||||
v.user.Prefs.Subscriptions = append(v.user.Prefs.Subscriptions, newSubscription)
|
||||
if err := s.userManager.ChangeSettings(v.user); err != nil {
|
||||
return err
|
||||
|
|
|
@ -153,9 +153,9 @@ func TestAccount_ChangeSettings(t *testing.T) {
|
|||
require.Equal(t, 200, rr.Code)
|
||||
account, _ := util.UnmarshalJSON[apiAccountResponse](io.NopCloser(rr.Body))
|
||||
require.Equal(t, "de", account.Language)
|
||||
require.Equal(t, 86400, account.Notification.DeleteAfter)
|
||||
require.Equal(t, "juntos", account.Notification.Sound)
|
||||
require.Equal(t, 0, account.Notification.MinPriority) // Not set
|
||||
require.Equal(t, util.Int(86400), account.Notification.DeleteAfter)
|
||||
require.Equal(t, util.String("juntos"), account.Notification.Sound)
|
||||
require.Nil(t, account.Notification.MinPriority) // Not set
|
||||
}
|
||||
|
||||
func TestAccount_Subscription_AddUpdateDelete(t *testing.T) {
|
||||
|
@ -176,7 +176,7 @@ func TestAccount_Subscription_AddUpdateDelete(t *testing.T) {
|
|||
require.NotEmpty(t, account.Subscriptions[0].ID)
|
||||
require.Equal(t, "http://abc.com", account.Subscriptions[0].BaseURL)
|
||||
require.Equal(t, "def", account.Subscriptions[0].Topic)
|
||||
require.Equal(t, "", account.Subscriptions[0].DisplayName)
|
||||
require.Nil(t, account.Subscriptions[0].DisplayName)
|
||||
|
||||
subscriptionID := account.Subscriptions[0].ID
|
||||
rr = request(t, s, "PATCH", "/v1/account/subscription/"+subscriptionID, `{"display_name": "ding dong"}`, map[string]string{
|
||||
|
@ -193,7 +193,7 @@ func TestAccount_Subscription_AddUpdateDelete(t *testing.T) {
|
|||
require.Equal(t, subscriptionID, account.Subscriptions[0].ID)
|
||||
require.Equal(t, "http://abc.com", account.Subscriptions[0].BaseURL)
|
||||
require.Equal(t, "def", account.Subscriptions[0].Topic)
|
||||
require.Equal(t, "ding dong", account.Subscriptions[0].DisplayName)
|
||||
require.Equal(t, util.String("ding dong"), account.Subscriptions[0].DisplayName)
|
||||
|
||||
rr = request(t, s, "DELETE", "/v1/account/subscription/"+subscriptionID, "", map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
|
@ -402,6 +402,7 @@ func TestAccount_Reservation_AddRemoveUserWithTierSuccess(t *testing.T) {
|
|||
AttachmentFileSizeLimit: 1231231,
|
||||
AttachmentTotalSizeLimit: 123123,
|
||||
AttachmentExpiryDuration: 10800 * time.Second,
|
||||
AttachmentBandwidthLimit: 21474836480,
|
||||
}))
|
||||
require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
|
||||
|
||||
|
@ -442,6 +443,7 @@ func TestAccount_Reservation_AddRemoveUserWithTierSuccess(t *testing.T) {
|
|||
require.Equal(t, int64(1231231), account.Limits.AttachmentFileSize)
|
||||
require.Equal(t, int64(123123), account.Limits.AttachmentTotalSize)
|
||||
require.Equal(t, int64(10800), account.Limits.AttachmentExpiryDuration)
|
||||
require.Equal(t, int64(21474836480), account.Limits.AttachmentBandwidth)
|
||||
require.Equal(t, 2, len(account.Reservations))
|
||||
require.Equal(t, "another", account.Reservations[0].Topic)
|
||||
require.Equal(t, "write-only", account.Reservations[0].Everyone)
|
||||
|
|
|
@ -265,6 +265,7 @@ func TestPayments_Webhook_Subscription_Updated_Downgrade_From_PastDue_To_Active(
|
|||
AttachmentExpiryDuration: time.Hour,
|
||||
AttachmentFileSizeLimit: 1000000,
|
||||
AttachmentTotalSizeLimit: 1000000,
|
||||
AttachmentBandwidthLimit: 1000000,
|
||||
}))
|
||||
require.Nil(t, s.userManager.CreateTier(&user.Tier{
|
||||
Code: "pro",
|
||||
|
@ -275,6 +276,7 @@ func TestPayments_Webhook_Subscription_Updated_Downgrade_From_PastDue_To_Active(
|
|||
AttachmentExpiryDuration: time.Hour,
|
||||
AttachmentFileSizeLimit: 1000000,
|
||||
AttachmentTotalSizeLimit: 1000000,
|
||||
AttachmentBandwidthLimit: 1000000,
|
||||
}))
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
|
||||
require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
|
||||
|
|
|
@ -1368,6 +1368,7 @@ func TestServer_PublishAttachmentWithTierBasedExpiry(t *testing.T) {
|
|||
AttachmentFileSizeLimit: 50_000,
|
||||
AttachmentTotalSizeLimit: 200_000,
|
||||
AttachmentExpiryDuration: sevenDays, // 7 days
|
||||
AttachmentBandwidthLimit: 100000,
|
||||
}))
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
|
||||
require.Nil(t, s.userManager.ChangeTier("phil", "test"))
|
||||
|
@ -1376,6 +1377,7 @@ func TestServer_PublishAttachmentWithTierBasedExpiry(t *testing.T) {
|
|||
response := request(t, s, "PUT", "/mytopic", content, map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 200, response.Code)
|
||||
msg := toMessage(t, response.Body.String())
|
||||
require.Contains(t, msg.Attachment.URL, "http://127.0.0.1:12345/file/")
|
||||
require.True(t, msg.Attachment.Expires > time.Now().Add(sevenDays-30*time.Second).Unix())
|
||||
|
@ -1396,6 +1398,46 @@ func TestServer_PublishAttachmentWithTierBasedExpiry(t *testing.T) {
|
|||
require.Equal(t, 200, response.Code)
|
||||
}
|
||||
|
||||
func TestServer_PublishAttachmentWithTierBasedBandwidthLimit(t *testing.T) {
|
||||
content := util.RandomString(5000) // > 4096
|
||||
|
||||
c := newTestConfigWithAuthFile(t)
|
||||
s := newTestServer(t, c)
|
||||
|
||||
// Create tier with certain limits
|
||||
require.Nil(t, s.userManager.CreateTier(&user.Tier{
|
||||
Code: "test",
|
||||
MessagesLimit: 10,
|
||||
MessagesExpiryDuration: time.Hour,
|
||||
AttachmentFileSizeLimit: 50_000,
|
||||
AttachmentTotalSizeLimit: 200_000,
|
||||
AttachmentExpiryDuration: time.Hour,
|
||||
AttachmentBandwidthLimit: 14000, // < 3x5000 bytes -> enough for one upload, one download
|
||||
}))
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
|
||||
require.Nil(t, s.userManager.ChangeTier("phil", "test"))
|
||||
|
||||
// Publish and make sure we can retrieve it
|
||||
rr := request(t, s, "PUT", "/mytopic", content, map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
msg := toMessage(t, rr.Body.String())
|
||||
|
||||
// Retrieve it (first time succeeds)
|
||||
rr = request(t, s, "GET", "/file/"+msg.ID, content, map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
require.Equal(t, content, rr.Body.String())
|
||||
|
||||
// Retrieve it AGAIN (fails, due to bandwidth limit)
|
||||
rr = request(t, s, "GET", "/file/"+msg.ID, content, map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 429, rr.Code)
|
||||
}
|
||||
|
||||
func TestServer_PublishAttachmentWithTierBasedLimits(t *testing.T) {
|
||||
smallFile := util.RandomString(20_000)
|
||||
largeFile := util.RandomString(50_000)
|
||||
|
@ -1412,6 +1454,7 @@ func TestServer_PublishAttachmentWithTierBasedLimits(t *testing.T) {
|
|||
AttachmentFileSizeLimit: 50_000,
|
||||
AttachmentTotalSizeLimit: 200_000,
|
||||
AttachmentExpiryDuration: 30 * time.Second,
|
||||
AttachmentBandwidthLimit: 1000000,
|
||||
}))
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
|
||||
require.Nil(t, s.userManager.ChangeTier("phil", "test"))
|
||||
|
|
|
@ -246,7 +246,7 @@ type apiAccountTier struct {
|
|||
}
|
||||
|
||||
type apiAccountLimits struct {
|
||||
Basis string `json:"basis,omitempty"` // "ip", "role" or "tier"
|
||||
Basis string `json:"basis,omitempty"` // "ip" or "tier"
|
||||
Messages int64 `json:"messages"`
|
||||
MessagesExpiryDuration int64 `json:"messages_expiry_duration"`
|
||||
Emails int64 `json:"emails"`
|
||||
|
@ -254,6 +254,7 @@ type apiAccountLimits struct {
|
|||
AttachmentTotalSize int64 `json:"attachment_total_size"`
|
||||
AttachmentFileSize int64 `json:"attachment_file_size"`
|
||||
AttachmentExpiryDuration int64 `json:"attachment_expiry_duration"`
|
||||
AttachmentBandwidth int64 `json:"attachment_bandwidth"`
|
||||
}
|
||||
|
||||
type apiAccountStats struct {
|
||||
|
|
|
@ -31,9 +31,9 @@ var (
|
|||
type visitor struct {
|
||||
config *Config
|
||||
messageCache *messageCache
|
||||
userManager *user.Manager // May be nil!
|
||||
ip netip.Addr
|
||||
user *user.User
|
||||
userManager *user.Manager // May be nil
|
||||
ip netip.Addr // Visitor IP address
|
||||
user *user.User // Only set if authenticated user, otherwise nil
|
||||
messages int64 // Number of messages sent, reset every day
|
||||
emails int64 // Number of emails sent, reset every day
|
||||
requestLimiter *rate.Limiter // Rate limiter for (almost) all requests (including messages)
|
||||
|
@ -61,6 +61,7 @@ type visitorLimits struct {
|
|||
AttachmentTotalSizeLimit int64
|
||||
AttachmentFileSizeLimit int64
|
||||
AttachmentExpiryDuration time.Duration
|
||||
AttachmentBandwidthLimit int64
|
||||
}
|
||||
|
||||
type visitorStats struct {
|
||||
|
@ -84,7 +85,7 @@ const (
|
|||
)
|
||||
|
||||
func newVisitor(conf *Config, messageCache *messageCache, userManager *user.Manager, ip netip.Addr, user *user.User) *visitor {
|
||||
var messagesLimiter util.Limiter
|
||||
var messagesLimiter, attachmentBandwidthLimiter util.Limiter
|
||||
var requestLimiter, emailsLimiter, accountLimiter *rate.Limiter
|
||||
var messages, emails int64
|
||||
if user != nil {
|
||||
|
@ -97,9 +98,11 @@ func newVisitor(conf *Config, messageCache *messageCache, userManager *user.Mana
|
|||
requestLimiter = rate.NewLimiter(dailyLimitToRate(user.Tier.MessagesLimit), conf.VisitorRequestLimitBurst)
|
||||
messagesLimiter = util.NewFixedLimiter(user.Tier.MessagesLimit)
|
||||
emailsLimiter = rate.NewLimiter(dailyLimitToRate(user.Tier.EmailsLimit), conf.VisitorEmailLimitBurst)
|
||||
attachmentBandwidthLimiter = util.NewBytesLimiter(int(user.Tier.AttachmentBandwidthLimit), 24*time.Hour)
|
||||
} else {
|
||||
requestLimiter = rate.NewLimiter(rate.Every(conf.VisitorRequestLimitReplenish), conf.VisitorRequestLimitBurst)
|
||||
emailsLimiter = rate.NewLimiter(rate.Every(conf.VisitorEmailLimitReplenish), conf.VisitorEmailLimitBurst)
|
||||
attachmentBandwidthLimiter = util.NewBytesLimiter(int(conf.VisitorAttachmentDailyBandwidthLimit), 24*time.Hour)
|
||||
}
|
||||
return &visitor{
|
||||
config: conf,
|
||||
|
@ -113,7 +116,7 @@ func newVisitor(conf *Config, messageCache *messageCache, userManager *user.Mana
|
|||
messagesLimiter: messagesLimiter, // May be nil
|
||||
emailsLimiter: emailsLimiter,
|
||||
subscriptionLimiter: util.NewFixedLimiter(int64(conf.VisitorSubscriptionLimit)),
|
||||
bandwidthLimiter: util.NewBytesLimiter(conf.VisitorAttachmentDailyBandwidthLimit, 24*time.Hour),
|
||||
bandwidthLimiter: attachmentBandwidthLimiter,
|
||||
accountLimiter: accountLimiter, // May be nil
|
||||
firebase: time.Unix(0, 0),
|
||||
seen: time.Now(),
|
||||
|
@ -259,6 +262,7 @@ func (v *visitor) Limits() *visitorLimits {
|
|||
limits.AttachmentTotalSizeLimit = v.user.Tier.AttachmentTotalSizeLimit
|
||||
limits.AttachmentFileSizeLimit = v.user.Tier.AttachmentFileSizeLimit
|
||||
limits.AttachmentExpiryDuration = v.user.Tier.AttachmentExpiryDuration
|
||||
limits.AttachmentBandwidthLimit = v.user.Tier.AttachmentBandwidthLimit
|
||||
}
|
||||
return limits
|
||||
}
|
||||
|
@ -327,5 +331,6 @@ func defaultVisitorLimits(conf *Config) *visitorLimits {
|
|||
AttachmentTotalSizeLimit: conf.VisitorAttachmentTotalSizeLimit,
|
||||
AttachmentFileSizeLimit: conf.AttachmentFileSizeLimit,
|
||||
AttachmentExpiryDuration: conf.AttachmentExpiryDuration,
|
||||
AttachmentBandwidthLimit: conf.VisitorAttachmentDailyBandwidthLimit,
|
||||
}
|
||||
}
|
||||
|
|
|
@ -52,6 +52,7 @@ const (
|
|||
attachment_file_size_limit INT NOT NULL,
|
||||
attachment_total_size_limit INT NOT NULL,
|
||||
attachment_expiry_duration INT NOT NULL,
|
||||
attachment_bandwidth_limit INT NOT NULL,
|
||||
stripe_price_id TEXT
|
||||
);
|
||||
CREATE UNIQUE INDEX idx_tier_code ON tier (code);
|
||||
|
@ -109,26 +110,26 @@ const (
|
|||
`
|
||||
|
||||
selectUserByIDQuery = `
|
||||
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.stripe_price_id
|
||||
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_price_id
|
||||
FROM user u
|
||||
LEFT JOIN tier t on t.id = u.tier_id
|
||||
WHERE u.id = ?
|
||||
`
|
||||
selectUserByNameQuery = `
|
||||
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.stripe_price_id
|
||||
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_price_id
|
||||
FROM user u
|
||||
LEFT JOIN tier t on t.id = u.tier_id
|
||||
WHERE user = ?
|
||||
`
|
||||
selectUserByTokenQuery = `
|
||||
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.stripe_price_id
|
||||
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_price_id
|
||||
FROM user u
|
||||
JOIN user_token t on u.id = t.user_id
|
||||
LEFT JOIN tier t on t.id = u.tier_id
|
||||
WHERE t.token = ? AND t.expires >= ?
|
||||
`
|
||||
selectUserByStripeCustomerIDQuery = `
|
||||
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.stripe_price_id
|
||||
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_price_id
|
||||
FROM user u
|
||||
LEFT JOIN tier t on t.id = u.tier_id
|
||||
WHERE u.stripe_customer_id = ?
|
||||
|
@ -232,20 +233,20 @@ const (
|
|||
`
|
||||
|
||||
insertTierQuery = `
|
||||
INSERT INTO tier (id, code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, stripe_price_id)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
INSERT INTO tier (id, code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_price_id)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
`
|
||||
selectTiersQuery = `
|
||||
SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, stripe_price_id
|
||||
SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_price_id
|
||||
FROM tier
|
||||
`
|
||||
selectTierByCodeQuery = `
|
||||
SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, stripe_price_id
|
||||
SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_price_id
|
||||
FROM tier
|
||||
WHERE code = ?
|
||||
`
|
||||
selectTierByPriceIDQuery = `
|
||||
SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, stripe_price_id
|
||||
SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_price_id
|
||||
FROM tier
|
||||
WHERE stripe_price_id = ?
|
||||
`
|
||||
|
@ -670,11 +671,11 @@ func (a *Manager) readUser(rows *sql.Rows) (*User, error) {
|
|||
var id, username, hash, role, prefs, syncTopic string
|
||||
var stripeCustomerID, stripeSubscriptionID, stripeSubscriptionStatus, stripePriceID, tierCode, tierName sql.NullString
|
||||
var messages, emails int64
|
||||
var messagesLimit, messagesExpiryDuration, emailsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration, stripeSubscriptionPaidUntil, stripeSubscriptionCancelAt, deleted sql.NullInt64
|
||||
var messagesLimit, messagesExpiryDuration, emailsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration, attachmentBandwidthLimit, stripeSubscriptionPaidUntil, stripeSubscriptionCancelAt, deleted sql.NullInt64
|
||||
if !rows.Next() {
|
||||
return nil, ErrUserNotFound
|
||||
}
|
||||
if err := rows.Scan(&id, &username, &hash, &role, &prefs, &syncTopic, &messages, &emails, &stripeCustomerID, &stripeSubscriptionID, &stripeSubscriptionStatus, &stripeSubscriptionPaidUntil, &stripeSubscriptionCancelAt, &deleted, &tierCode, &tierName, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &stripePriceID); err != nil {
|
||||
if err := rows.Scan(&id, &username, &hash, &role, &prefs, &syncTopic, &messages, &emails, &stripeCustomerID, &stripeSubscriptionID, &stripeSubscriptionStatus, &stripeSubscriptionPaidUntil, &stripeSubscriptionCancelAt, &deleted, &tierCode, &tierName, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &attachmentBandwidthLimit, &stripePriceID); err != nil {
|
||||
return nil, err
|
||||
} else if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
|
@ -714,6 +715,7 @@ func (a *Manager) readUser(rows *sql.Rows) (*User, error) {
|
|||
AttachmentFileSizeLimit: attachmentFileSizeLimit.Int64,
|
||||
AttachmentTotalSizeLimit: attachmentTotalSizeLimit.Int64,
|
||||
AttachmentExpiryDuration: time.Duration(attachmentExpiryDuration.Int64) * time.Second,
|
||||
AttachmentBandwidthLimit: attachmentBandwidthLimit.Int64,
|
||||
StripePriceID: stripePriceID.String, // May be empty
|
||||
}
|
||||
}
|
||||
|
@ -994,7 +996,7 @@ func (a *Manager) DefaultAccess() Permission {
|
|||
// CreateTier creates a new tier in the database
|
||||
func (a *Manager) CreateTier(tier *Tier) error {
|
||||
tierID := util.RandomStringPrefix(tierIDPrefix, tierIDLength)
|
||||
if _, err := a.db.Exec(insertTierQuery, tierID, tier.Code, tier.Name, tier.MessagesLimit, int64(tier.MessagesExpiryDuration.Seconds()), tier.EmailsLimit, tier.ReservationsLimit, tier.AttachmentFileSizeLimit, tier.AttachmentTotalSizeLimit, int64(tier.AttachmentExpiryDuration.Seconds()), tier.StripePriceID); err != nil {
|
||||
if _, err := a.db.Exec(insertTierQuery, tierID, tier.Code, tier.Name, tier.MessagesLimit, int64(tier.MessagesExpiryDuration.Seconds()), tier.EmailsLimit, tier.ReservationsLimit, tier.AttachmentFileSizeLimit, tier.AttachmentTotalSizeLimit, int64(tier.AttachmentExpiryDuration.Seconds()), tier.AttachmentBandwidthLimit, tier.StripePriceID); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
|
@ -1051,11 +1053,11 @@ func (a *Manager) TierByStripePrice(priceID string) (*Tier, error) {
|
|||
func (a *Manager) readTier(rows *sql.Rows) (*Tier, error) {
|
||||
var id, code, name string
|
||||
var stripePriceID sql.NullString
|
||||
var messagesLimit, messagesExpiryDuration, emailsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration sql.NullInt64
|
||||
var messagesLimit, messagesExpiryDuration, emailsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration, attachmentBandwidthLimit sql.NullInt64
|
||||
if !rows.Next() {
|
||||
return nil, ErrTierNotFound
|
||||
}
|
||||
if err := rows.Scan(&id, &code, &name, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &stripePriceID); err != nil {
|
||||
if err := rows.Scan(&id, &code, &name, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &attachmentBandwidthLimit, &stripePriceID); err != nil {
|
||||
return nil, err
|
||||
} else if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
|
@ -1072,6 +1074,7 @@ func (a *Manager) readTier(rows *sql.Rows) (*Tier, error) {
|
|||
AttachmentFileSizeLimit: attachmentFileSizeLimit.Int64,
|
||||
AttachmentTotalSizeLimit: attachmentTotalSizeLimit.Int64,
|
||||
AttachmentExpiryDuration: time.Duration(attachmentExpiryDuration.Int64) * time.Second,
|
||||
AttachmentBandwidthLimit: attachmentBandwidthLimit.Int64,
|
||||
StripePriceID: stripePriceID.String, // May be empty
|
||||
}, nil
|
||||
}
|
||||
|
|
|
@ -3,6 +3,7 @@ package user
|
|||
import (
|
||||
"database/sql"
|
||||
"github.com/stretchr/testify/require"
|
||||
"heckel.io/ntfy/util"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
@ -583,21 +584,21 @@ func TestManager_ChangeSettings(t *testing.T) {
|
|||
require.Nil(t, err)
|
||||
require.Nil(t, u.Prefs.Subscriptions)
|
||||
require.Nil(t, u.Prefs.Notification)
|
||||
require.Equal(t, "", u.Prefs.Language)
|
||||
require.Nil(t, u.Prefs.Language)
|
||||
|
||||
// Save with new settings
|
||||
u.Prefs = &Prefs{
|
||||
Language: "de",
|
||||
Language: util.String("de"),
|
||||
Notification: &NotificationPrefs{
|
||||
Sound: "ding",
|
||||
MinPriority: 2,
|
||||
Sound: util.String("ding"),
|
||||
MinPriority: util.Int(2),
|
||||
},
|
||||
Subscriptions: []*Subscription{
|
||||
{
|
||||
ID: "someID",
|
||||
BaseURL: "https://ntfy.sh",
|
||||
Topic: "mytopic",
|
||||
DisplayName: "My Topic",
|
||||
DisplayName: util.String("My Topic"),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
@ -606,14 +607,14 @@ func TestManager_ChangeSettings(t *testing.T) {
|
|||
// Read again
|
||||
u, err = a.User("ben")
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "de", u.Prefs.Language)
|
||||
require.Equal(t, "ding", u.Prefs.Notification.Sound)
|
||||
require.Equal(t, 2, u.Prefs.Notification.MinPriority)
|
||||
require.Equal(t, 0, u.Prefs.Notification.DeleteAfter)
|
||||
require.Equal(t, util.String("de"), u.Prefs.Language)
|
||||
require.Equal(t, util.String("ding"), u.Prefs.Notification.Sound)
|
||||
require.Equal(t, util.Int(2), u.Prefs.Notification.MinPriority)
|
||||
require.Nil(t, u.Prefs.Notification.DeleteAfter)
|
||||
require.Equal(t, "someID", u.Prefs.Subscriptions[0].ID)
|
||||
require.Equal(t, "https://ntfy.sh", u.Prefs.Subscriptions[0].BaseURL)
|
||||
require.Equal(t, "mytopic", u.Prefs.Subscriptions[0].Topic)
|
||||
require.Equal(t, "My Topic", u.Prefs.Subscriptions[0].DisplayName)
|
||||
require.Equal(t, util.String("My Topic"), u.Prefs.Subscriptions[0].DisplayName)
|
||||
}
|
||||
|
||||
func TestSqliteCache_Migration_From1(t *testing.T) {
|
||||
|
|
|
@ -50,17 +50,18 @@ type Prefs struct {
|
|||
|
||||
// Tier represents a user's account type, including its account limits
|
||||
type Tier struct {
|
||||
ID string
|
||||
Code string
|
||||
Name string
|
||||
MessagesLimit int64
|
||||
MessagesExpiryDuration time.Duration
|
||||
EmailsLimit int64
|
||||
ReservationsLimit int64
|
||||
AttachmentFileSizeLimit int64
|
||||
AttachmentTotalSizeLimit int64
|
||||
AttachmentExpiryDuration time.Duration
|
||||
StripePriceID string
|
||||
ID string // Tier identifier (ti_...)
|
||||
Code string // Code of the tier
|
||||
Name string // Name of the tier
|
||||
MessagesLimit int64 // Daily message limit
|
||||
MessagesExpiryDuration time.Duration // Cache duration for messages
|
||||
EmailsLimit int64 // Daily email limit
|
||||
ReservationsLimit int64 // Number of topic reservations allowed by user
|
||||
AttachmentFileSizeLimit int64 // Max file size per file (bytes)
|
||||
AttachmentTotalSizeLimit int64 // Total file size for all files of this user (bytes)
|
||||
AttachmentExpiryDuration time.Duration // Duration after which attachments will be deleted
|
||||
AttachmentBandwidthLimit int64 // Daily bandwidth limit for the user
|
||||
StripePriceID string // Price ID for paid tiers (price_...)
|
||||
}
|
||||
|
||||
// Subscription represents a user's topic subscription
|
||||
|
|
10
util/util.go
10
util/util.go
|
@ -336,3 +336,13 @@ func Retry[T any](f func() (*T, error), after ...time.Duration) (t *T, err error
|
|||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// String turns a string into a pointer of a string
|
||||
func String(v string) *string {
|
||||
return &v
|
||||
}
|
||||
|
||||
// Int turns a string into a pointer of an int
|
||||
func Int(v int) *int {
|
||||
return &v
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue